1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{BinaryExpr, Case, Expr, Operator, SortExpr, col};
38
/// Prefix of the aliases given to extracted common subexpressions
/// (the generated names look like `__common_expr_1`, `__common_expr_2`, ...).
const CSE_PREFIX: &str = "__common_expr";

/// Optimizer rule that performs common subexpression elimination (CSE).
///
/// Repeated subexpressions found in the expressions of `Projection`, `Sort`,
/// `Filter`, `Window` and `Aggregate` nodes are computed once, under a
/// `__common_expr`-prefixed alias, in a projection inserted below the node,
/// and the node's expressions are rewritten to reference that alias.
/// Whenever such a rewrite changes a node's output schema, a projection that
/// restores the original schema is added on top.
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71 pub fn new() -> Self {
72 Self {}
73 }
74
75 fn try_optimize_proj(
76 &self,
77 projection: Projection,
78 config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 let Projection {
81 expr,
82 input,
83 schema,
84 ..
85 } = projection;
86 let input = Arc::unwrap_or_clone(input);
87 self.try_unary_plan(expr, input, config)?
88 .map_data(|(new_expr, new_input)| {
89 Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90 .map(LogicalPlan::Projection)
91 })
92 }
93
94 fn try_optimize_sort(
95 &self,
96 sort: Sort,
97 config: &dyn OptimizerConfig,
98 ) -> Result<Transformed<LogicalPlan>> {
99 let Sort { expr, input, fetch } = sort;
100 let input = Arc::unwrap_or_clone(input);
101 let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102 .into_iter()
103 .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104 .unzip();
105 let new_sort = self
106 .try_unary_plan(sort_expressions, input, config)?
107 .update_data(|(new_expr, new_input)| {
108 LogicalPlan::Sort(Sort {
109 expr: new_expr
110 .into_iter()
111 .zip(sort_params)
112 .map(|(expr, (asc, nulls_first))| SortExpr {
113 expr,
114 asc,
115 nulls_first,
116 })
117 .collect(),
118 input: Arc::new(new_input),
119 fetch,
120 })
121 });
122 Ok(new_sort)
123 }
124
125 fn try_optimize_filter(
126 &self,
127 filter: Filter,
128 config: &dyn OptimizerConfig,
129 ) -> Result<Transformed<LogicalPlan>> {
130 let Filter {
131 predicate, input, ..
132 } = filter;
133 let input = Arc::unwrap_or_clone(input);
134 let expr = vec![predicate];
135 self.try_unary_plan(expr, input, config)?
136 .map_data(|(mut new_expr, new_input)| {
137 assert_eq!(new_expr.len(), 1); let new_predicate = new_expr.pop().unwrap();
139 Filter::try_new(new_predicate, Arc::new(new_input))
140 .map(LogicalPlan::Filter)
141 })
142 }
143
144 fn try_optimize_window(
145 &self,
146 window: Window,
147 config: &dyn OptimizerConfig,
148 ) -> Result<Transformed<LogicalPlan>> {
149 let (window_expr_list, window_schemas, input) =
152 get_consecutive_window_exprs(window);
153
154 match CSE::new(ExprCSEController::new(
157 config.alias_generator().as_ref(),
158 ExprMask::Normal,
159 ))
160 .extract_common_nodes(window_expr_list)?
161 {
162 FoundCommonNodes::Yes {
166 common_nodes: common_exprs,
167 new_nodes_list: new_exprs_list,
168 original_nodes_list: original_exprs_list,
169 } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170 Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171 }),
172 FoundCommonNodes::No {
173 original_nodes_list: original_exprs_list,
174 } => Ok(Transformed::no((original_exprs_list, input, None))),
175 }?
176 .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179 self.rewrite(new_input, config)?.map_data(|new_input| {
180 Ok((new_window_expr_list, new_input, window_expr_list))
181 })
182 })?
183 .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185 if let Some(window_expr_list) = window_expr_list {
194 let name_preserver = NamePreserver::new_for_projection();
195 let saved_names = window_expr_list
196 .iter()
197 .map(|exprs| {
198 exprs
199 .iter()
200 .map(|expr| name_preserver.save(expr))
201 .collect::<Vec<_>>()
202 })
203 .collect::<Vec<_>>();
204 new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205 new_input,
206 |plan, (new_window_expr, saved_names)| {
207 let new_window_expr = new_window_expr
208 .into_iter()
209 .zip(saved_names)
210 .map(|(new_window_expr, saved_name)| {
211 saved_name.restore(new_window_expr)
212 })
213 .collect::<Vec<_>>();
214 Window::try_new(new_window_expr, Arc::new(plan))
215 .map(LogicalPlan::Window)
216 },
217 )
218 } else {
219 new_window_expr_list
220 .into_iter()
221 .zip(window_schemas)
222 .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223 Window::try_new_with_schema(
224 new_window_expr,
225 Arc::new(plan),
226 schema,
227 )
228 .map(LogicalPlan::Window)
229 })
230 }
231 })
232 }
233
234 fn try_optimize_aggregate(
235 &self,
236 aggregate: Aggregate,
237 config: &dyn OptimizerConfig,
238 ) -> Result<Transformed<LogicalPlan>> {
239 let Aggregate {
240 group_expr,
241 aggr_expr,
242 input,
243 schema,
244 ..
245 } = aggregate;
246 let input = Arc::unwrap_or_clone(input);
247 match CSE::new(ExprCSEController::new(
249 config.alias_generator().as_ref(),
250 ExprMask::Normal,
251 ))
252 .extract_common_nodes(vec![group_expr, aggr_expr])?
253 {
254 FoundCommonNodes::Yes {
258 common_nodes: common_exprs,
259 new_nodes_list: mut new_exprs_list,
260 original_nodes_list: mut original_exprs_list,
261 } => {
262 let new_aggr_expr = new_exprs_list.pop().unwrap();
263 let new_group_expr = new_exprs_list.pop().unwrap();
264
265 build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266 let aggr_expr = original_exprs_list.pop().unwrap();
267 Transformed::yes((
268 new_aggr_expr,
269 new_group_expr,
270 new_input,
271 Some(aggr_expr),
272 ))
273 })
274 }
275
276 FoundCommonNodes::No {
277 original_nodes_list: mut original_exprs_list,
278 } => {
279 let new_aggr_expr = original_exprs_list.pop().unwrap();
280 let new_group_expr = original_exprs_list.pop().unwrap();
281
282 Ok(Transformed::no((
283 new_aggr_expr,
284 new_group_expr,
285 input,
286 None,
287 )))
288 }
289 }?
290 .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293 self.rewrite(new_input, config)?.map_data(|new_input| {
294 Ok((
295 new_aggr_expr,
296 new_group_expr,
297 aggr_expr,
298 Arc::new(new_input),
299 ))
300 })
301 })?
302 .transform_data(
304 |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305 match CSE::new(ExprCSEController::new(
307 config.alias_generator().as_ref(),
308 ExprMask::NormalAndAggregates,
309 ))
310 .extract_common_nodes(vec![new_aggr_expr])?
311 {
312 FoundCommonNodes::Yes {
313 common_nodes: common_exprs,
314 new_nodes_list: mut new_exprs_list,
315 original_nodes_list: mut original_exprs_list,
316 } => {
317 let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318 let new_aggr_expr = original_exprs_list.pop().unwrap();
319 let saved_names = if let Some(aggr_expr) = aggr_expr {
320 let name_preserver = NamePreserver::new_for_projection();
321 aggr_expr
322 .iter()
323 .map(|expr| Some(name_preserver.save(expr)))
324 .collect::<Vec<_>>()
325 } else {
326 new_aggr_expr
327 .clone()
328 .into_iter()
329 .map(|_| None)
330 .collect::<Vec<_>>()
331 };
332
333 let mut agg_exprs = common_exprs
334 .into_iter()
335 .map(|(expr, expr_alias)| expr.alias(expr_alias))
336 .collect::<Vec<_>>();
337
338 let mut proj_exprs = vec![];
339 for expr in &new_group_expr {
340 extract_expressions(expr, &mut proj_exprs)
341 }
342 for ((expr_rewritten, expr_orig), saved_name) in
343 rewritten_aggr_expr
344 .into_iter()
345 .zip(new_aggr_expr)
346 .zip(saved_names)
347 {
348 if expr_rewritten == expr_orig {
349 let expr_rewritten = if let Some(saved_name) = saved_name
350 {
351 saved_name.restore(expr_rewritten)
352 } else {
353 expr_rewritten
354 };
355 if let Expr::Alias(Alias { expr, name, .. }) =
356 expr_rewritten
357 {
358 agg_exprs.push(expr.alias(&name));
359 proj_exprs
360 .push(Expr::Column(Column::from_name(name)));
361 } else {
362 let expr_alias =
363 config.alias_generator().next(CSE_PREFIX);
364 let (qualifier, field_name) =
365 expr_rewritten.qualified_name();
366 let out_name =
367 qualified_name(qualifier.as_ref(), &field_name);
368
369 agg_exprs.push(expr_rewritten.alias(&expr_alias));
370 proj_exprs.push(
371 Expr::Column(Column::from_name(expr_alias))
372 .alias(out_name),
373 );
374 }
375 } else {
376 proj_exprs.push(expr_rewritten);
377 }
378 }
379
380 let agg = LogicalPlan::Aggregate(Aggregate::try_new(
381 new_input,
382 new_group_expr,
383 agg_exprs,
384 )?);
385 Projection::try_new(proj_exprs, Arc::new(agg))
386 .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
387 }
388
389 FoundCommonNodes::No {
392 original_nodes_list: mut original_exprs_list,
393 } => {
394 let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
395
396 if let Some(aggr_expr) = aggr_expr {
407 let name_preserver = NamePreserver::new_for_projection();
408 let saved_names = aggr_expr
409 .iter()
410 .map(|expr| name_preserver.save(expr))
411 .collect::<Vec<_>>();
412 let new_aggr_expr = rewritten_aggr_expr
413 .into_iter()
414 .zip(saved_names)
415 .map(|(new_expr, saved_name)| {
416 saved_name.restore(new_expr)
417 })
418 .collect::<Vec<Expr>>();
419
420 Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
423 .map(LogicalPlan::Aggregate)
424 .map(Transformed::no)
425 } else {
426 Aggregate::try_new_with_schema(
427 new_input,
428 new_group_expr,
429 rewritten_aggr_expr,
430 schema,
431 )
432 .map(LogicalPlan::Aggregate)
433 .map(Transformed::no)
434 }
435 }
436 }
437 },
438 )
439 }
440
441 fn try_unary_plan(
456 &self,
457 exprs: Vec<Expr>,
458 input: LogicalPlan,
459 config: &dyn OptimizerConfig,
460 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
461 match CSE::new(ExprCSEController::new(
463 config.alias_generator().as_ref(),
464 ExprMask::Normal,
465 ))
466 .extract_common_nodes(vec![exprs])?
467 {
468 FoundCommonNodes::Yes {
469 common_nodes: common_exprs,
470 new_nodes_list: mut new_exprs_list,
471 original_nodes_list: _,
472 } => {
473 let new_exprs = new_exprs_list.pop().unwrap();
474 build_common_expr_project_plan(input, common_exprs)
475 .map(|new_input| Transformed::yes((new_exprs, new_input)))
476 }
477 FoundCommonNodes::No {
478 original_nodes_list: mut original_exprs_list,
479 } => {
480 let new_exprs = original_exprs_list.pop().unwrap();
481 Ok(Transformed::no((new_exprs, input)))
482 }
483 }?
484 .transform_data(|(new_exprs, new_input)| {
487 self.rewrite(new_input, config)?
488 .map_data(|new_input| Ok((new_exprs, new_input)))
489 })
490 }
491}
492
493fn get_consecutive_window_exprs(
525 window: Window,
526) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
527 let mut window_expr_list = vec![];
528 let mut window_schemas = vec![];
529 let mut plan = LogicalPlan::Window(window);
530 while let LogicalPlan::Window(Window {
531 input,
532 window_expr,
533 schema,
534 }) = plan
535 {
536 window_expr_list.push(window_expr);
537 window_schemas.push(schema);
538
539 plan = Arc::unwrap_or_clone(input);
540 }
541 (window_expr_list, window_schemas, plan)
542}
543
544impl OptimizerRule for CommonSubexprEliminate {
545 fn supports_rewrite(&self) -> bool {
546 true
547 }
548
549 fn apply_order(&self) -> Option<ApplyOrder> {
550 None
554 }
555
556 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
557 fn rewrite(
558 &self,
559 plan: LogicalPlan,
560 config: &dyn OptimizerConfig,
561 ) -> Result<Transformed<LogicalPlan>> {
562 let original_schema = Arc::clone(plan.schema());
563
564 let optimized_plan = match plan {
565 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
566 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
567 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
568 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
569 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
570 LogicalPlan::Join(_)
571 | LogicalPlan::Repartition(_)
572 | LogicalPlan::Union(_)
573 | LogicalPlan::TableScan(_)
574 | LogicalPlan::Values(_)
575 | LogicalPlan::EmptyRelation(_)
576 | LogicalPlan::Subquery(_)
577 | LogicalPlan::SubqueryAlias(_)
578 | LogicalPlan::Limit(_)
579 | LogicalPlan::Ddl(_)
580 | LogicalPlan::Explain(_)
581 | LogicalPlan::Analyze(_)
582 | LogicalPlan::Statement(_)
583 | LogicalPlan::DescribeTable(_)
584 | LogicalPlan::Distinct(_)
585 | LogicalPlan::Extension(_)
586 | LogicalPlan::Dml(_)
587 | LogicalPlan::Copy(_)
588 | LogicalPlan::Unnest(_)
589 | LogicalPlan::RecursiveQuery(_) => {
590 plan.map_children(|c| self.rewrite(c, config))?
593 }
594 };
595
596 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
598 {
599 optimized_plan.map_data(|optimized_plan| {
600 build_recover_project_plan(&original_schema, optimized_plan)
601 })
602 } else {
603 Ok(optimized_plan)
604 }
605 }
606
607 fn name(&self) -> &str {
608 "common_sub_expression_eliminate"
609 }
610}
611
612#[derive(Debug, Clone, Copy)]
614enum ExprMask {
615 Normal,
624
625 NormalAndAggregates,
627}
628
629struct ExprCSEController<'a> {
630 alias_generator: &'a AliasGenerator,
631 mask: ExprMask,
632
633 alias_counter: usize,
635}
636
637impl<'a> ExprCSEController<'a> {
638 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
639 Self {
640 alias_generator,
641 mask,
642 alias_counter: 0,
643 }
644 }
645}
646
647impl CSEController for ExprCSEController<'_> {
648 type Node = Expr;
649
650 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
651 match node {
652 Expr::ScalarFunction(ScalarFunction { func, args }) => {
656 func.conditional_arguments(args)
657 }
658
659 Expr::BinaryExpr(BinaryExpr {
662 left,
663 op: Operator::And | Operator::Or,
664 right,
665 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
666
667 Expr::Case(Case {
671 expr,
672 when_then_expr,
673 else_expr,
674 }) => Some((
675 expr.iter()
676 .map(|e| e.as_ref())
677 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
678 .collect(),
679 when_then_expr
680 .iter()
681 .take(1)
682 .map(|(_, then)| then.as_ref())
683 .chain(
684 when_then_expr
685 .iter()
686 .skip(1)
687 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
688 )
689 .chain(else_expr.iter().map(|e| e.as_ref()))
690 .collect(),
691 )),
692 _ => None,
693 }
694 }
695
696 fn is_valid(node: &Expr) -> bool {
697 !node.is_volatile_node()
698 }
699
700 fn is_ignored(&self, node: &Expr) -> bool {
701 #[expect(deprecated)]
703 let is_normal_minus_aggregates = matches!(
704 node,
705 Expr::Literal(..)
706 | Expr::Column(..)
707 | Expr::ScalarVariable(..)
708 | Expr::Alias(..)
709 | Expr::Wildcard { .. }
710 );
711
712 let is_aggr = matches!(node, Expr::AggregateFunction(..));
713
714 match self.mask {
715 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
716 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
717 }
718 }
719
720 fn generate_alias(&self) -> String {
721 self.alias_generator.next(CSE_PREFIX)
722 }
723
724 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
725 if self.alias_counter > 0 {
727 col(alias)
728 } else {
729 self.alias_counter += 1;
730 col(alias).alias(node.schema_name().to_string())
731 }
732 }
733
734 fn rewrite_f_down(&mut self, node: &Expr) {
735 if matches!(node, Expr::Alias(_)) {
736 self.alias_counter += 1;
737 }
738 }
739 fn rewrite_f_up(&mut self, node: &Expr) {
740 if matches!(node, Expr::Alias(_)) {
741 self.alias_counter -= 1
742 }
743 }
744}
745
746impl Default for CommonSubexprEliminate {
747 fn default() -> Self {
748 Self::new()
749 }
750}
751
752fn build_common_expr_project_plan(
763 input: LogicalPlan,
764 common_exprs: Vec<(Expr, String)>,
765) -> Result<LogicalPlan> {
766 let mut fields_set = BTreeSet::new();
767 let mut project_exprs = common_exprs
768 .into_iter()
769 .map(|(expr, expr_alias)| {
770 fields_set.insert(expr_alias.clone());
771 Ok(expr.alias(expr_alias))
772 })
773 .collect::<Result<Vec<_>>>()?;
774
775 for (qualifier, field) in input.schema().iter() {
776 if fields_set.insert(qualified_name(qualifier, field.name())) {
777 project_exprs.push(Expr::from((qualifier, field)));
778 }
779 }
780
781 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
782}
783
784fn build_recover_project_plan(
790 schema: &DFSchema,
791 input: LogicalPlan,
792) -> Result<LogicalPlan> {
793 let col_exprs = schema.iter().map(Expr::from).collect();
794 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
795}
796
797fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
798 if let Expr::GroupingSet(groupings) = expr {
799 for e in groupings.distinct_expr() {
800 let (qualifier, field_name) = e.qualified_name();
801 let col = Column::new(qualifier, field_name);
802 result.push(Expr::Column(col))
803 }
804 } else {
805 let (qualifier, field_name) = expr.qualified_name();
806 let col = Column::new(qualifier, field_name);
807 result.push(Expr::Column(col));
808 }
809}
810
811#[cfg(test)]
812mod test {
813 use std::any::Any;
814 use std::iter;
815
816 use arrow::datatypes::{DataType, Field, Schema};
817 use datafusion_expr::logical_plan::{JoinType, table_scan};
818 use datafusion_expr::{
819 AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs,
820 ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
821 grouping_set, is_null, not,
822 };
823 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
824
825 use super::*;
826 use crate::assert_optimized_plan_eq_snapshot;
827 use crate::optimizer::OptimizerContext;
828 use crate::test::*;
829 use datafusion_expr::test::function_stub::{avg, sum};
830
831 macro_rules! assert_optimized_plan_equal {
832 (
833 $config:expr,
834 $plan:expr,
835 @ $expected:literal $(,)?
836 ) => {{
837 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
838 assert_optimized_plan_eq_snapshot!(
839 $config,
840 rules,
841 $plan,
842 @ $expected,
843 )
844 }};
845
846 (
847 $plan:expr,
848 @ $expected:literal $(,)?
849 ) => {{
850 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
851 let optimizer_ctx = OptimizerContext::new();
852 assert_optimized_plan_eq_snapshot!(
853 optimizer_ctx,
854 rules,
855 $plan,
856 @ $expected,
857 )
858 }};
859 }
860
861 #[test]
862 fn tpch_q1_simplified() -> Result<()> {
863 let table_scan = test_table_scan()?;
872
873 let plan = LogicalPlanBuilder::from(table_scan)
874 .aggregate(
875 iter::empty::<Expr>(),
876 vec![
877 sum(col("a") * (lit(1) - col("b"))),
878 sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
879 ],
880 )?
881 .build()?;
882
883 assert_optimized_plan_equal!(
884 plan,
885 @ r"
886 Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
887 Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
888 TableScan: test
889 "
890 )
891 }
892
893 #[test]
894 fn nested_aliases() -> Result<()> {
895 let table_scan = test_table_scan()?;
896
897 let plan = LogicalPlanBuilder::from(table_scan)
898 .project(vec![
899 (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
900 col("a") + col("b"),
901 ])?
902 .build()?;
903
904 assert_optimized_plan_equal!(
905 plan,
906 @ r"
907 Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
908 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
909 TableScan: test
910 "
911 )
912 }
913
914 #[test]
915 fn aggregate() -> Result<()> {
916 let table_scan = test_table_scan()?;
917
918 let return_type = DataType::UInt32;
919 let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
920 let udf_agg = |inner: Expr| {
921 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
922 Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
923 "my_agg",
924 Signature::exact(vec![DataType::UInt32], Volatility::Stable),
925 return_type.clone(),
926 Arc::clone(&accumulator),
927 vec![Field::new("value", DataType::UInt32, true).into()],
928 ))),
929 vec![inner],
930 false,
931 None,
932 vec![],
933 None,
934 ))
935 };
936
937 let plan = LogicalPlanBuilder::from(table_scan.clone())
939 .aggregate(
940 iter::empty::<Expr>(),
941 vec![
942 avg(col("a")).alias("col1"),
944 avg(col("a")).alias("col2"),
945 avg(col("b")).alias("col3"),
947 avg(col("c")),
948 udf_agg(col("a")).alias("col4"),
950 udf_agg(col("a")).alias("col5"),
951 udf_agg(col("b")).alias("col6"),
953 udf_agg(col("c")),
954 ],
955 )?
956 .build()?;
957
958 assert_optimized_plan_equal!(
959 plan,
960 @ r"
961 Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
962 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
963 TableScan: test
964 "
965 )?;
966
967 let plan = LogicalPlanBuilder::from(table_scan.clone())
969 .aggregate(
970 iter::empty::<Expr>(),
971 vec![
972 lit(1) + avg(col("a")),
973 lit(1) - avg(col("a")),
974 lit(1) + udf_agg(col("a")),
975 lit(1) - udf_agg(col("a")),
976 ],
977 )?
978 .build()?;
979
980 assert_optimized_plan_equal!(
981 plan,
982 @ r"
983 Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
984 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
985 TableScan: test
986 "
987 )?;
988
989 let plan = LogicalPlanBuilder::from(table_scan.clone())
991 .aggregate(
992 iter::empty::<Expr>(),
993 vec![
994 avg(lit(1u32) + col("a")).alias("col1"),
995 udf_agg(lit(1u32) + col("a")).alias("col2"),
996 ],
997 )?
998 .build()?;
999
1000 assert_optimized_plan_equal!(
1001 plan,
1002 @ r"
1003 Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1004 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1005 TableScan: test
1006 "
1007 )?;
1008
1009 let plan = LogicalPlanBuilder::from(table_scan.clone())
1011 .aggregate(
1012 vec![lit(1u32) + col("a")],
1013 vec![
1014 avg(lit(1u32) + col("a")).alias("col1"),
1015 udf_agg(lit(1u32) + col("a")).alias("col2"),
1016 ],
1017 )?
1018 .build()?;
1019
1020 assert_optimized_plan_equal!(
1021 plan,
1022 @ r"
1023 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1024 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1025 TableScan: test
1026 "
1027 )?;
1028
1029 let plan = LogicalPlanBuilder::from(table_scan)
1031 .aggregate(
1032 vec![lit(1u32) + col("a")],
1033 vec![
1034 (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1035 (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1036 avg(lit(1u32) + col("a")),
1037 (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1038 (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1039 udf_agg(lit(1u32) + col("a")),
1040 ],
1041 )?
1042 .build()?;
1043
1044 assert_optimized_plan_equal!(
1045 plan,
1046 @ r"
1047 Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1048 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1049 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1050 TableScan: test
1051 "
1052 )
1053 }
1054
1055 #[test]
1056 fn aggregate_with_relations_and_dots() -> Result<()> {
1057 let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1058 let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1059
1060 let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1061
1062 let plan = LogicalPlanBuilder::from(table_scan)
1063 .aggregate(
1064 vec![col_a.clone()],
1065 vec![
1066 (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1067 avg(lit(1u32) + col_a),
1068 ],
1069 )?
1070 .build()?;
1071
1072 assert_optimized_plan_equal!(
1073 plan,
1074 @ r"
1075 Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1076 Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1077 Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1078 TableScan: table.test
1079 "
1080 )
1081 }
1082
1083 #[test]
1084 fn subexpr_in_same_order() -> Result<()> {
1085 let table_scan = test_table_scan()?;
1086
1087 let plan = LogicalPlanBuilder::from(table_scan)
1088 .project(vec![
1089 (lit(1) + col("a")).alias("first"),
1090 (lit(1) + col("a")).alias("second"),
1091 ])?
1092 .build()?;
1093
1094 assert_optimized_plan_equal!(
1095 plan,
1096 @ r"
1097 Projection: __common_expr_1 AS first, __common_expr_1 AS second
1098 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1099 TableScan: test
1100 "
1101 )
1102 }
1103
1104 #[test]
1105 fn subexpr_in_different_order() -> Result<()> {
1106 let table_scan = test_table_scan()?;
1107
1108 let plan = LogicalPlanBuilder::from(table_scan)
1109 .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1110 .build()?;
1111
1112 assert_optimized_plan_equal!(
1113 plan,
1114 @ r"
1115 Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1116 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1117 TableScan: test
1118 "
1119 )
1120 }
1121
1122 #[test]
1123 fn cross_plans_subexpr() -> Result<()> {
1124 let table_scan = test_table_scan()?;
1125
1126 let plan = LogicalPlanBuilder::from(table_scan)
1127 .project(vec![lit(1) + col("a"), col("a")])?
1128 .project(vec![lit(1) + col("a")])?
1129 .build()?;
1130
1131 assert_optimized_plan_equal!(
1132 plan,
1133 @ r"
1134 Projection: Int32(1) + test.a
1135 Projection: Int32(1) + test.a, test.a
1136 TableScan: test
1137 "
1138 )
1139 }
1140
1141 #[test]
1142 fn redundant_project_fields() {
1143 let table_scan = test_table_scan().unwrap();
1144 let c_plus_a = col("c") + col("a");
1145 let b_plus_a = col("b") + col("a");
1146 let common_exprs_1 = vec![
1147 (c_plus_a, format!("{CSE_PREFIX}_1")),
1148 (b_plus_a, format!("{CSE_PREFIX}_2")),
1149 ];
1150 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1151 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1152 let common_exprs_2 = vec![
1153 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1154 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1155 ];
1156 let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1157 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1158
1159 let mut field_set = BTreeSet::new();
1160 for name in project_2.schema().field_names() {
1161 assert!(field_set.insert(name));
1162 }
1163 }
1164
1165 #[test]
1166 fn redundant_project_fields_join_input() {
1167 let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1168 let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1169 let join = LogicalPlanBuilder::from(table_scan_1)
1170 .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1171 .unwrap()
1172 .build()
1173 .unwrap();
1174 let c_plus_a = col("test1.c") + col("test1.a");
1175 let b_plus_a = col("test1.b") + col("test1.a");
1176 let common_exprs_1 = vec![
1177 (c_plus_a, format!("{CSE_PREFIX}_1")),
1178 (b_plus_a, format!("{CSE_PREFIX}_2")),
1179 ];
1180 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1181 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1182 let common_exprs_2 = vec![
1183 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1184 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1185 ];
1186 let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1187 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1188
1189 let mut field_set = BTreeSet::new();
1190 for name in project_2.schema().field_names() {
1191 assert!(field_set.insert(name));
1192 }
1193 }
1194
1195 #[test]
1196 fn eliminated_subexpr_datatype() {
1197 use datafusion_expr::cast;
1198
1199 let schema = Schema::new(vec![
1200 Field::new("a", DataType::UInt64, false),
1201 Field::new("b", DataType::UInt64, false),
1202 Field::new("c", DataType::UInt64, false),
1203 ]);
1204
1205 let plan = table_scan(Some("table"), &schema, None)
1206 .unwrap()
1207 .filter(
1208 cast(col("a"), DataType::Int64)
1209 .lt(lit(1_i64))
1210 .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1211 )
1212 .unwrap()
1213 .build()
1214 .unwrap();
1215 let rule = CommonSubexprEliminate::new();
1216 let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1217 assert!(optimized_plan.transformed);
1218 let optimized_plan = optimized_plan.data;
1219
1220 let schema = optimized_plan.schema();
1221 let fields_with_datatypes: Vec<_> = schema
1222 .fields()
1223 .iter()
1224 .map(|field| (field.name(), field.data_type()))
1225 .collect();
1226 let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1227 let expected = r#"[
1228 (
1229 "a",
1230 UInt64,
1231 ),
1232 (
1233 "b",
1234 UInt64,
1235 ),
1236 (
1237 "c",
1238 UInt64,
1239 ),
1240]"#;
1241 assert_eq!(expected, formatted_fields_with_datatype);
1242 }
1243
1244 #[test]
1245 fn filter_schema_changed() -> Result<()> {
1246 let table_scan = test_table_scan()?;
1247
1248 let plan = LogicalPlanBuilder::from(table_scan)
1249 .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1250 .build()?;
1251
1252 assert_optimized_plan_equal!(
1253 plan,
1254 @ r"
1255 Projection: test.a, test.b, test.c
1256 Filter: __common_expr_1 - Int32(10) > __common_expr_1
1257 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1258 TableScan: test
1259 "
1260 )
1261 }
1262
1263 #[test]
1264 fn test_extract_expressions_from_grouping_set() -> Result<()> {
1265 let mut result = Vec::with_capacity(3);
1266 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1267 extract_expressions(&grouping, &mut result);
1268
1269 assert!(result.len() == 3);
1270 Ok(())
1271 }
1272
1273 #[test]
1274 fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1275 let mut result = Vec::with_capacity(2);
1276 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1277 extract_expressions(&grouping, &mut result);
1278 assert!(result.len() == 2);
1279 Ok(())
1280 }
1281
1282 #[test]
1283 fn test_alias_collision() -> Result<()> {
1284 let table_scan = test_table_scan()?;
1285
1286 let config = OptimizerContext::new();
1287 let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1288 let plan = LogicalPlanBuilder::from(table_scan.clone())
1289 .project(vec![
1290 (col("a") + col("b")).alias(common_expr_1.clone()),
1291 col("c"),
1292 ])?
1293 .project(vec![
1294 col(common_expr_1.clone()).alias("c1"),
1295 col(common_expr_1).alias("c2"),
1296 (col("c") + lit(2)).alias("c3"),
1297 (col("c") + lit(2)).alias("c4"),
1298 ])?
1299 .build()?;
1300
1301 assert_optimized_plan_equal!(
1302 config,
1303 plan,
1304 @ r"
1305 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1306 Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1307 Projection: test.a + test.b AS __common_expr_1, test.c
1308 TableScan: test
1309 "
1310 )?;
1311
1312 let config = OptimizerContext::new();
1313 let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1314 let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1315 let plan = LogicalPlanBuilder::from(table_scan)
1316 .project(vec![
1317 (col("a") + col("b")).alias(common_expr_2.clone()),
1318 col("c"),
1319 ])?
1320 .project(vec![
1321 col(common_expr_2.clone()).alias("c1"),
1322 col(common_expr_2).alias("c2"),
1323 (col("c") + lit(2)).alias("c3"),
1324 (col("c") + lit(2)).alias("c4"),
1325 ])?
1326 .build()?;
1327
1328 assert_optimized_plan_equal!(
1329 config,
1330 plan,
1331 @ r"
1332 Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1333 Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1334 Projection: test.a + test.b AS __common_expr_2, test.c
1335 TableScan: test
1336 "
1337 )?;
1338
1339 Ok(())
1340 }
1341
1342 #[test]
1343 fn test_extract_expressions_from_col() -> Result<()> {
1344 let mut result = Vec::with_capacity(1);
1345 extract_expressions(&col("a"), &mut result);
1346 assert!(result.len() == 1);
1347 Ok(())
1348 }
1349
1350 #[test]
1351 fn test_short_circuits() -> Result<()> {
1352 let table_scan = test_table_scan()?;
1353
1354 let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1355 let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1356 let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1357 let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1358 let plan = LogicalPlanBuilder::from(table_scan)
1359 .project(vec![
1360 extracted_short_circuit.clone().alias("c1"),
1361 extracted_short_circuit.alias("c2"),
1362 extracted_short_circuit_leg_1
1363 .clone()
1364 .or(not_extracted_short_circuit_leg_2.clone())
1365 .alias("c3"),
1366 extracted_short_circuit_leg_1
1367 .and(not_extracted_short_circuit_leg_2)
1368 .alias("c4"),
1369 extracted_short_circuit_leg_3
1370 .clone()
1371 .or(extracted_short_circuit_leg_3)
1372 .alias("c5"),
1373 ])?
1374 .build()?;
1375
1376 assert_optimized_plan_equal!(
1377 plan,
1378 @ r"
1379 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1380 Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1381 TableScan: test
1382 "
1383 )
1384 }
1385
1386 #[test]
1387 fn test_volatile() -> Result<()> {
1388 let table_scan = test_table_scan()?;
1389
1390 let extracted_child = col("a") + col("b");
1391 let rand = rand_func().call(vec![]);
1392 let not_extracted_volatile = extracted_child + rand;
1393 let plan = LogicalPlanBuilder::from(table_scan)
1394 .project(vec![
1395 not_extracted_volatile.clone().alias("c1"),
1396 not_extracted_volatile.alias("c2"),
1397 ])?
1398 .build()?;
1399
1400 assert_optimized_plan_equal!(
1401 plan,
1402 @ r"
1403 Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1404 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1405 TableScan: test
1406 "
1407 )
1408 }
1409
1410 #[test]
1411 fn test_volatile_short_circuits() -> Result<()> {
1412 let table_scan = test_table_scan()?;
1413
1414 let rand = rand_func().call(vec![]);
1415 let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1416 let not_extracted_volatile_short_circuit_1 =
1417 extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1418 let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1419 let not_extracted_volatile_short_circuit_2 =
1420 rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1421 let plan = LogicalPlanBuilder::from(table_scan)
1422 .project(vec![
1423 not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1424 not_extracted_volatile_short_circuit_1.alias("c2"),
1425 not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1426 not_extracted_volatile_short_circuit_2.alias("c4"),
1427 ])?
1428 .build()?;
1429
1430 assert_optimized_plan_equal!(
1431 plan,
1432 @ r"
1433 Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1434 Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1435 TableScan: test
1436 "
1437 )
1438 }
1439
1440 #[test]
1441 fn test_non_top_level_common_expression() -> Result<()> {
1442 let table_scan = test_table_scan()?;
1443
1444 let common_expr = col("a") + col("b");
1445 let plan = LogicalPlanBuilder::from(table_scan)
1446 .project(vec![
1447 common_expr.clone().alias("c1"),
1448 common_expr.alias("c2"),
1449 ])?
1450 .project(vec![col("c1"), col("c2")])?
1451 .build()?;
1452
1453 assert_optimized_plan_equal!(
1454 plan,
1455 @ r"
1456 Projection: c1, c2
1457 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1458 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1459 TableScan: test
1460 "
1461 )
1462 }
1463
1464 #[test]
1465 fn test_nested_common_expression() -> Result<()> {
1466 let table_scan = test_table_scan()?;
1467
1468 let nested_common_expr = col("a") + col("b");
1469 let common_expr = nested_common_expr.clone() * nested_common_expr;
1470 let plan = LogicalPlanBuilder::from(table_scan)
1471 .project(vec![
1472 common_expr.clone().alias("c1"),
1473 common_expr.alias("c2"),
1474 ])?
1475 .build()?;
1476
1477 assert_optimized_plan_equal!(
1478 plan,
1479 @ r"
1480 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1481 Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1482 Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1483 TableScan: test
1484 "
1485 )
1486 }
1487
1488 #[test]
1489 fn test_normalize_add_expression() -> Result<()> {
1490 let table_scan = test_table_scan()?;
1492 let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1493 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1494
1495 assert_optimized_plan_equal!(
1496 plan,
1497 @ r"
1498 Projection: test.a, test.b, test.c
1499 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1500 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1501 TableScan: test
1502 "
1503 )
1504 }
1505
1506 #[test]
1507 fn test_normalize_multi_expression() -> Result<()> {
1508 let table_scan = test_table_scan()?;
1510 let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1511 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1512
1513 assert_optimized_plan_equal!(
1514 plan,
1515 @ r"
1516 Projection: test.a, test.b, test.c
1517 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1518 Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1519 TableScan: test
1520 "
1521 )
1522 }
1523
1524 #[test]
1525 fn test_normalize_bitset_and_expression() -> Result<()> {
1526 let table_scan = test_table_scan()?;
1528 let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1529 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1530
1531 assert_optimized_plan_equal!(
1532 plan,
1533 @ r"
1534 Projection: test.a, test.b, test.c
1535 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1536 Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1537 TableScan: test
1538 "
1539 )
1540 }
1541
1542 #[test]
1543 fn test_normalize_bitset_or_expression() -> Result<()> {
1544 let table_scan = test_table_scan()?;
1546 let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1547 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1548
1549 assert_optimized_plan_equal!(
1550 plan,
1551 @ r"
1552 Projection: test.a, test.b, test.c
1553 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1554 Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1555 TableScan: test
1556 "
1557 )
1558 }
1559
1560 #[test]
1561 fn test_normalize_bitset_xor_expression() -> Result<()> {
1562 let table_scan = test_table_scan()?;
1564 let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1565 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1566
1567 assert_optimized_plan_equal!(
1568 plan,
1569 @ r"
1570 Projection: test.a, test.b, test.c
1571 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1572 Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1573 TableScan: test
1574 "
1575 )
1576 }
1577
1578 #[test]
1579 fn test_normalize_eq_expression() -> Result<()> {
1580 let table_scan = test_table_scan()?;
1582 let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1583 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1584
1585 assert_optimized_plan_equal!(
1586 plan,
1587 @ r"
1588 Projection: test.a, test.b, test.c
1589 Filter: __common_expr_1 AND __common_expr_1
1590 Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1591 TableScan: test
1592 "
1593 )
1594 }
1595
1596 #[test]
1597 fn test_normalize_ne_expression() -> Result<()> {
1598 let table_scan = test_table_scan()?;
1600 let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1601 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1602
1603 assert_optimized_plan_equal!(
1604 plan,
1605 @ r"
1606 Projection: test.a, test.b, test.c
1607 Filter: __common_expr_1 AND __common_expr_1
1608 Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1609 TableScan: test
1610 "
1611 )
1612 }
1613
1614 #[test]
1615 fn test_normalize_complex_expression() -> Result<()> {
1616 let table_scan = test_table_scan()?;
1618 let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1619 .eq(lit(30));
1620 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1621
1622 assert_optimized_plan_equal!(
1623 plan,
1624 @ r"
1625 Projection: test.a, test.b, test.c
1626 Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1627 Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1628 TableScan: test
1629 "
1630 )?;
1631
1632 let table_scan = test_table_scan()?;
1634 let expr = (((col("a") + col("b") / col("c")) * col("c"))
1635 / (col("c") * (col("b") / col("c") + col("a")))
1636 + col("a"))
1637 .eq(lit(30));
1638 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1639
1640 assert_optimized_plan_equal!(
1641 plan,
1642 @ r"
1643 Projection: test.a, test.b, test.c
1644 Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1645 Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1646 TableScan: test
1647 "
1648 )?;
1649
1650 let table_scan = test_table_scan()?;
1652 let expr = ((col("b") / (col("a") + col("c")))
1653 * (col("b") / (col("c") + col("a"))))
1654 .eq(lit(30));
1655 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1656 assert_optimized_plan_equal!(
1657 plan,
1658 @ r"
1659 Projection: test.a, test.b, test.c
1660 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1661 Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1662 TableScan: test
1663 "
1664 )?;
1665
1666 Ok(())
1667 }
1668
1669 #[derive(Debug, PartialEq, Eq, Hash)]
1670 pub struct TestUdf {
1671 signature: Signature,
1672 }
1673
1674 impl TestUdf {
1675 pub fn new() -> Self {
1676 Self {
1677 signature: Signature::numeric(1, Volatility::Immutable),
1678 }
1679 }
1680 }
1681
1682 impl ScalarUDFImpl for TestUdf {
1683 fn as_any(&self) -> &dyn Any {
1684 self
1685 }
1686 fn name(&self) -> &str {
1687 "my_udf"
1688 }
1689
1690 fn signature(&self) -> &Signature {
1691 &self.signature
1692 }
1693
1694 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1695 Ok(DataType::Int32)
1696 }
1697
1698 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1699 panic!("not implemented")
1700 }
1701 }
1702
1703 #[test]
1704 fn test_normalize_inner_binary_expression() -> Result<()> {
1705 let table_scan = test_table_scan()?;
1707 let expr1 = not(col("a").eq(col("b")));
1708 let expr2 = not(col("b").eq(col("a")));
1709 let plan = LogicalPlanBuilder::from(table_scan)
1710 .project(vec![expr1, expr2])?
1711 .build()?;
1712 assert_optimized_plan_equal!(
1713 plan,
1714 @ r"
1715 Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1716 Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1717 TableScan: test
1718 "
1719 )?;
1720
1721 let table_scan = test_table_scan()?;
1723 let expr1 = is_null(col("a").eq(col("b")));
1724 let expr2 = is_null(col("b").eq(col("a")));
1725 let plan = LogicalPlanBuilder::from(table_scan)
1726 .project(vec![expr1, expr2])?
1727 .build()?;
1728 assert_optimized_plan_equal!(
1729 plan,
1730 @ r"
1731 Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1732 Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1733 TableScan: test
1734 "
1735 )?;
1736
1737 let table_scan = test_table_scan()?;
1739 let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1740 let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1741 let plan = LogicalPlanBuilder::from(table_scan)
1742 .project(vec![expr1, expr2])?
1743 .build()?;
1744 assert_optimized_plan_equal!(
1745 plan,
1746 @ r"
1747 Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1748 Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1749 TableScan: test
1750 "
1751 )?;
1752
1753 let table_scan = test_table_scan()?;
1755 let expr1 = col("c").between(col("a") + col("b"), lit(10));
1756 let expr2 = col("c").between(col("b") + col("a"), lit(10));
1757 let plan = LogicalPlanBuilder::from(table_scan)
1758 .project(vec![expr1, expr2])?
1759 .build()?;
1760 assert_optimized_plan_equal!(
1761 plan,
1762 @ r"
1763 Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1764 Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1765 TableScan: test
1766 "
1767 )?;
1768
1769 let udf = ScalarUDF::from(TestUdf::new());
1771 let table_scan = test_table_scan()?;
1772 let expr1 = udf.call(vec![col("a") + col("b")]);
1773 let expr2 = udf.call(vec![col("b") + col("a")]);
1774 let plan = LogicalPlanBuilder::from(table_scan)
1775 .project(vec![expr1, expr2])?
1776 .build()?;
1777 assert_optimized_plan_equal!(
1778 plan,
1779 @ r"
1780 Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1781 Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1782 TableScan: test
1783 "
1784 )
1785 }
1786
1787 fn rand_func() -> ScalarUDF {
1793 ScalarUDF::new_from_impl(RandomStub::new())
1794 }
1795
1796 #[derive(Debug, PartialEq, Eq, Hash)]
1797 struct RandomStub {
1798 signature: Signature,
1799 }
1800
1801 impl RandomStub {
1802 fn new() -> Self {
1803 Self {
1804 signature: Signature::exact(vec![], Volatility::Volatile),
1805 }
1806 }
1807 }
1808 impl ScalarUDFImpl for RandomStub {
1809 fn as_any(&self) -> &dyn Any {
1810 self
1811 }
1812
1813 fn name(&self) -> &str {
1814 "random"
1815 }
1816
1817 fn signature(&self) -> &Signature {
1818 &self.signature
1819 }
1820
1821 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1822 Ok(DataType::Float64)
1823 }
1824
1825 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1826 panic!("dummy - not implemented")
1827 }
1828 }
1829}