1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
/// Prefix passed to the alias generator when naming extracted common
/// sub-expressions; the generated aliases look like `__common_expr_1`.
const CSE_PREFIX: &str = "__common_expr";
40
/// Optimizer rule that performs common sub-expression elimination.
///
/// Expressions that occur more than once among the expressions of a
/// `Projection`, `Sort`, `Filter`, `Window` or `Aggregate` node are computed
/// once (in an inserted projection, or in the aggregate itself for repeated
/// aggregate functions) under a generated `__common_expr_N` alias and are
/// then referenced through that alias.
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71 pub fn new() -> Self {
72 Self {}
73 }
74
75 fn try_optimize_proj(
76 &self,
77 projection: Projection,
78 config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 let Projection {
81 expr,
82 input,
83 schema,
84 ..
85 } = projection;
86 let input = Arc::unwrap_or_clone(input);
87 self.try_unary_plan(expr, input, config)?
88 .map_data(|(new_expr, new_input)| {
89 Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90 .map(LogicalPlan::Projection)
91 })
92 }
93
94 fn try_optimize_sort(
95 &self,
96 sort: Sort,
97 config: &dyn OptimizerConfig,
98 ) -> Result<Transformed<LogicalPlan>> {
99 let Sort { expr, input, fetch } = sort;
100 let input = Arc::unwrap_or_clone(input);
101 let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102 .into_iter()
103 .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104 .unzip();
105 let new_sort = self
106 .try_unary_plan(sort_expressions, input, config)?
107 .update_data(|(new_expr, new_input)| {
108 LogicalPlan::Sort(Sort {
109 expr: new_expr
110 .into_iter()
111 .zip(sort_params)
112 .map(|(expr, (asc, nulls_first))| SortExpr {
113 expr,
114 asc,
115 nulls_first,
116 })
117 .collect(),
118 input: Arc::new(new_input),
119 fetch,
120 })
121 });
122 Ok(new_sort)
123 }
124
125 fn try_optimize_filter(
126 &self,
127 filter: Filter,
128 config: &dyn OptimizerConfig,
129 ) -> Result<Transformed<LogicalPlan>> {
130 let Filter {
131 predicate, input, ..
132 } = filter;
133 let input = Arc::unwrap_or_clone(input);
134 let expr = vec![predicate];
135 self.try_unary_plan(expr, input, config)?
136 .map_data(|(mut new_expr, new_input)| {
137 assert_eq!(new_expr.len(), 1); let new_predicate = new_expr.pop().unwrap();
139 Filter::try_new(new_predicate, Arc::new(new_input))
140 .map(LogicalPlan::Filter)
141 })
142 }
143
144 fn try_optimize_window(
145 &self,
146 window: Window,
147 config: &dyn OptimizerConfig,
148 ) -> Result<Transformed<LogicalPlan>> {
149 let (window_expr_list, window_schemas, input) =
152 get_consecutive_window_exprs(window);
153
154 match CSE::new(ExprCSEController::new(
157 config.alias_generator().as_ref(),
158 ExprMask::Normal,
159 ))
160 .extract_common_nodes(window_expr_list)?
161 {
162 FoundCommonNodes::Yes {
166 common_nodes: common_exprs,
167 new_nodes_list: new_exprs_list,
168 original_nodes_list: original_exprs_list,
169 } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170 Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171 }),
172 FoundCommonNodes::No {
173 original_nodes_list: original_exprs_list,
174 } => Ok(Transformed::no((original_exprs_list, input, None))),
175 }?
176 .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179 self.rewrite(new_input, config)?.map_data(|new_input| {
180 Ok((new_window_expr_list, new_input, window_expr_list))
181 })
182 })?
183 .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185 if let Some(window_expr_list) = window_expr_list {
194 let name_preserver = NamePreserver::new_for_projection();
195 let saved_names = window_expr_list
196 .iter()
197 .map(|exprs| {
198 exprs
199 .iter()
200 .map(|expr| name_preserver.save(expr))
201 .collect::<Vec<_>>()
202 })
203 .collect::<Vec<_>>();
204 new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205 new_input,
206 |plan, (new_window_expr, saved_names)| {
207 let new_window_expr = new_window_expr
208 .into_iter()
209 .zip(saved_names)
210 .map(|(new_window_expr, saved_name)| {
211 saved_name.restore(new_window_expr)
212 })
213 .collect::<Vec<_>>();
214 Window::try_new(new_window_expr, Arc::new(plan))
215 .map(LogicalPlan::Window)
216 },
217 )
218 } else {
219 new_window_expr_list
220 .into_iter()
221 .zip(window_schemas)
222 .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223 Window::try_new_with_schema(
224 new_window_expr,
225 Arc::new(plan),
226 schema,
227 )
228 .map(LogicalPlan::Window)
229 })
230 }
231 })
232 }
233
234 fn try_optimize_aggregate(
235 &self,
236 aggregate: Aggregate,
237 config: &dyn OptimizerConfig,
238 ) -> Result<Transformed<LogicalPlan>> {
239 let Aggregate {
240 group_expr,
241 aggr_expr,
242 input,
243 schema,
244 ..
245 } = aggregate;
246 let input = Arc::unwrap_or_clone(input);
247 match CSE::new(ExprCSEController::new(
249 config.alias_generator().as_ref(),
250 ExprMask::Normal,
251 ))
252 .extract_common_nodes(vec![group_expr, aggr_expr])?
253 {
254 FoundCommonNodes::Yes {
258 common_nodes: common_exprs,
259 new_nodes_list: mut new_exprs_list,
260 original_nodes_list: mut original_exprs_list,
261 } => {
262 let new_aggr_expr = new_exprs_list.pop().unwrap();
263 let new_group_expr = new_exprs_list.pop().unwrap();
264
265 build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266 let aggr_expr = original_exprs_list.pop().unwrap();
267 Transformed::yes((
268 new_aggr_expr,
269 new_group_expr,
270 new_input,
271 Some(aggr_expr),
272 ))
273 })
274 }
275
276 FoundCommonNodes::No {
277 original_nodes_list: mut original_exprs_list,
278 } => {
279 let new_aggr_expr = original_exprs_list.pop().unwrap();
280 let new_group_expr = original_exprs_list.pop().unwrap();
281
282 Ok(Transformed::no((
283 new_aggr_expr,
284 new_group_expr,
285 input,
286 None,
287 )))
288 }
289 }?
290 .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293 self.rewrite(new_input, config)?.map_data(|new_input| {
294 Ok((
295 new_aggr_expr,
296 new_group_expr,
297 aggr_expr,
298 Arc::new(new_input),
299 ))
300 })
301 })?
302 .transform_data(
304 |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305 match CSE::new(ExprCSEController::new(
307 config.alias_generator().as_ref(),
308 ExprMask::NormalAndAggregates,
309 ))
310 .extract_common_nodes(vec![new_aggr_expr])?
311 {
312 FoundCommonNodes::Yes {
313 common_nodes: common_exprs,
314 new_nodes_list: mut new_exprs_list,
315 original_nodes_list: mut original_exprs_list,
316 } => {
317 let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318 let new_aggr_expr = original_exprs_list.pop().unwrap();
319 let saved_names = if let Some(aggr_expr) = aggr_expr {
320 let name_preserver = NamePreserver::new_for_projection();
321 aggr_expr
322 .iter()
323 .map(|expr| Some(name_preserver.save(expr)))
324 .collect::<Vec<_>>()
325 } else {
326 new_aggr_expr
327 .clone()
328 .into_iter()
329 .map(|_| None)
330 .collect::<Vec<_>>()
331 };
332
333 let mut agg_exprs = common_exprs
334 .into_iter()
335 .map(|(expr, expr_alias)| expr.alias(expr_alias))
336 .collect::<Vec<_>>();
337
338 let mut proj_exprs = vec![];
339 for expr in &new_group_expr {
340 extract_expressions(expr, &mut proj_exprs)
341 }
342 for ((expr_rewritten, expr_orig), saved_name) in
343 rewritten_aggr_expr
344 .into_iter()
345 .zip(new_aggr_expr)
346 .zip(saved_names)
347 {
348 if expr_rewritten == expr_orig {
349 let expr_rewritten = if let Some(saved_name) = saved_name
350 {
351 saved_name.restore(expr_rewritten)
352 } else {
353 expr_rewritten
354 };
355 if let Expr::Alias(Alias { expr, name, .. }) =
356 expr_rewritten
357 {
358 agg_exprs.push(expr.alias(&name));
359 proj_exprs
360 .push(Expr::Column(Column::from_name(name)));
361 } else {
362 let expr_alias =
363 config.alias_generator().next(CSE_PREFIX);
364 let (qualifier, field_name) =
365 expr_rewritten.qualified_name();
366 let out_name =
367 qualified_name(qualifier.as_ref(), &field_name);
368
369 agg_exprs.push(expr_rewritten.alias(&expr_alias));
370 proj_exprs.push(
371 Expr::Column(Column::from_name(expr_alias))
372 .alias(out_name),
373 );
374 }
375 } else {
376 proj_exprs.push(expr_rewritten);
377 }
378 }
379
380 let agg = LogicalPlan::Aggregate(Aggregate::try_new(
381 new_input,
382 new_group_expr,
383 agg_exprs,
384 )?);
385 Projection::try_new(proj_exprs, Arc::new(agg))
386 .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
387 }
388
389 FoundCommonNodes::No {
392 original_nodes_list: mut original_exprs_list,
393 } => {
394 let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
395
396 if let Some(aggr_expr) = aggr_expr {
407 let name_preserver = NamePreserver::new_for_projection();
408 let saved_names = aggr_expr
409 .iter()
410 .map(|expr| name_preserver.save(expr))
411 .collect::<Vec<_>>();
412 let new_aggr_expr = rewritten_aggr_expr
413 .into_iter()
414 .zip(saved_names)
415 .map(|(new_expr, saved_name)| {
416 saved_name.restore(new_expr)
417 })
418 .collect::<Vec<Expr>>();
419
420 Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
423 .map(LogicalPlan::Aggregate)
424 .map(Transformed::no)
425 } else {
426 Aggregate::try_new_with_schema(
427 new_input,
428 new_group_expr,
429 rewritten_aggr_expr,
430 schema,
431 )
432 .map(LogicalPlan::Aggregate)
433 .map(Transformed::no)
434 }
435 }
436 }
437 },
438 )
439 }
440
441 fn try_unary_plan(
456 &self,
457 exprs: Vec<Expr>,
458 input: LogicalPlan,
459 config: &dyn OptimizerConfig,
460 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
461 match CSE::new(ExprCSEController::new(
463 config.alias_generator().as_ref(),
464 ExprMask::Normal,
465 ))
466 .extract_common_nodes(vec![exprs])?
467 {
468 FoundCommonNodes::Yes {
469 common_nodes: common_exprs,
470 new_nodes_list: mut new_exprs_list,
471 original_nodes_list: _,
472 } => {
473 let new_exprs = new_exprs_list.pop().unwrap();
474 build_common_expr_project_plan(input, common_exprs)
475 .map(|new_input| Transformed::yes((new_exprs, new_input)))
476 }
477 FoundCommonNodes::No {
478 original_nodes_list: mut original_exprs_list,
479 } => {
480 let new_exprs = original_exprs_list.pop().unwrap();
481 Ok(Transformed::no((new_exprs, input)))
482 }
483 }?
484 .transform_data(|(new_exprs, new_input)| {
487 self.rewrite(new_input, config)?
488 .map_data(|new_input| Ok((new_exprs, new_input)))
489 })
490 }
491}
492
493fn get_consecutive_window_exprs(
525 window: Window,
526) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
527 let mut window_expr_list = vec![];
528 let mut window_schemas = vec![];
529 let mut plan = LogicalPlan::Window(window);
530 while let LogicalPlan::Window(Window {
531 input,
532 window_expr,
533 schema,
534 }) = plan
535 {
536 window_expr_list.push(window_expr);
537 window_schemas.push(schema);
538
539 plan = Arc::unwrap_or_clone(input);
540 }
541 (window_expr_list, window_schemas, plan)
542}
543
544impl OptimizerRule for CommonSubexprEliminate {
545 fn supports_rewrite(&self) -> bool {
546 true
547 }
548
549 fn apply_order(&self) -> Option<ApplyOrder> {
550 None
554 }
555
556 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
557 fn rewrite(
558 &self,
559 plan: LogicalPlan,
560 config: &dyn OptimizerConfig,
561 ) -> Result<Transformed<LogicalPlan>> {
562 let original_schema = Arc::clone(plan.schema());
563
564 let optimized_plan = match plan {
565 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
566 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
567 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
568 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
569 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
570 LogicalPlan::Join(_)
571 | LogicalPlan::Repartition(_)
572 | LogicalPlan::Union(_)
573 | LogicalPlan::TableScan(_)
574 | LogicalPlan::Values(_)
575 | LogicalPlan::EmptyRelation(_)
576 | LogicalPlan::Subquery(_)
577 | LogicalPlan::SubqueryAlias(_)
578 | LogicalPlan::Limit(_)
579 | LogicalPlan::Ddl(_)
580 | LogicalPlan::Explain(_)
581 | LogicalPlan::Analyze(_)
582 | LogicalPlan::Statement(_)
583 | LogicalPlan::DescribeTable(_)
584 | LogicalPlan::Distinct(_)
585 | LogicalPlan::Extension(_)
586 | LogicalPlan::Dml(_)
587 | LogicalPlan::Copy(_)
588 | LogicalPlan::Unnest(_)
589 | LogicalPlan::RecursiveQuery(_) => {
590 plan.map_children(|c| self.rewrite(c, config))?
593 }
594 };
595
596 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
598 {
599 optimized_plan.map_data(|optimized_plan| {
600 build_recover_project_plan(&original_schema, optimized_plan)
601 })
602 } else {
603 Ok(optimized_plan)
604 }
605 }
606
607 fn name(&self) -> &str {
608 "common_sub_expression_eliminate"
609 }
610}
611
/// Selects which kinds of expressions are ignored (never extracted) by the
/// common sub-expression search.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores literals, columns, scalar variables, aliases, wildcards and
    /// aggregate functions.
    Normal,

    /// Like `Normal`, except that aggregate functions are not ignored, so
    /// repeated aggregate functions can be extracted as well.
    NormalAndAggregates,
}
628
/// `CSEController` implementation for `Expr` trees.
struct ExprCSEController<'a> {
    /// Source of the unique aliases given to extracted expressions.
    alias_generator: &'a AliasGenerator,
    /// Decides which expression kinds are ignored.
    mask: ExprMask,

    /// Bookkeeping used while rewriting: incremented on entering an
    /// `Expr::Alias` node (and once an alias has been emitted by `rewrite`),
    /// decremented on leaving an `Expr::Alias` node. A value above zero makes
    /// `rewrite` emit the bare column instead of an aliased one.
    alias_counter: usize,
}
636
637impl<'a> ExprCSEController<'a> {
638 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
639 Self {
640 alias_generator,
641 mask,
642 alias_counter: 0,
643 }
644 }
645}
646
647impl CSEController for ExprCSEController<'_> {
648 type Node = Expr;
649
650 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
651 match node {
652 Expr::ScalarFunction(ScalarFunction { func, args })
656 if func.short_circuits() =>
657 {
658 Some((vec![], args.iter().collect()))
659 }
660
661 Expr::BinaryExpr(BinaryExpr {
664 left,
665 op: Operator::And | Operator::Or,
666 right,
667 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
668
669 Expr::Case(Case {
673 expr,
674 when_then_expr,
675 else_expr,
676 }) => Some((
677 expr.iter()
678 .map(|e| e.as_ref())
679 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
680 .collect(),
681 when_then_expr
682 .iter()
683 .take(1)
684 .map(|(_, then)| then.as_ref())
685 .chain(
686 when_then_expr
687 .iter()
688 .skip(1)
689 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
690 )
691 .chain(else_expr.iter().map(|e| e.as_ref()))
692 .collect(),
693 )),
694 _ => None,
695 }
696 }
697
698 fn is_valid(node: &Expr) -> bool {
699 !node.is_volatile_node()
700 }
701
702 fn is_ignored(&self, node: &Expr) -> bool {
703 #[expect(deprecated)]
705 let is_normal_minus_aggregates = matches!(
706 node,
707 Expr::Literal(..)
708 | Expr::Column(..)
709 | Expr::ScalarVariable(..)
710 | Expr::Alias(..)
711 | Expr::Wildcard { .. }
712 );
713
714 let is_aggr = matches!(node, Expr::AggregateFunction(..));
715
716 match self.mask {
717 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
718 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
719 }
720 }
721
722 fn generate_alias(&self) -> String {
723 self.alias_generator.next(CSE_PREFIX)
724 }
725
726 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
727 if self.alias_counter > 0 {
729 col(alias)
730 } else {
731 self.alias_counter += 1;
732 col(alias).alias(node.schema_name().to_string())
733 }
734 }
735
736 fn rewrite_f_down(&mut self, node: &Expr) {
737 if matches!(node, Expr::Alias(_)) {
738 self.alias_counter += 1;
739 }
740 }
741 fn rewrite_f_up(&mut self, node: &Expr) {
742 if matches!(node, Expr::Alias(_)) {
743 self.alias_counter -= 1
744 }
745 }
746}
747
748impl Default for CommonSubexprEliminate {
749 fn default() -> Self {
750 Self::new()
751 }
752}
753
754fn build_common_expr_project_plan(
765 input: LogicalPlan,
766 common_exprs: Vec<(Expr, String)>,
767) -> Result<LogicalPlan> {
768 let mut fields_set = BTreeSet::new();
769 let mut project_exprs = common_exprs
770 .into_iter()
771 .map(|(expr, expr_alias)| {
772 fields_set.insert(expr_alias.clone());
773 Ok(expr.alias(expr_alias))
774 })
775 .collect::<Result<Vec<_>>>()?;
776
777 for (qualifier, field) in input.schema().iter() {
778 if fields_set.insert(qualified_name(qualifier, field.name())) {
779 project_exprs.push(Expr::from((qualifier, field)));
780 }
781 }
782
783 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
784}
785
786fn build_recover_project_plan(
792 schema: &DFSchema,
793 input: LogicalPlan,
794) -> Result<LogicalPlan> {
795 let col_exprs = schema.iter().map(Expr::from).collect();
796 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
797}
798
799fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
800 if let Expr::GroupingSet(groupings) = expr {
801 for e in groupings.distinct_expr() {
802 let (qualifier, field_name) = e.qualified_name();
803 let col = Column::new(qualifier, field_name);
804 result.push(Expr::Column(col))
805 }
806 } else {
807 let (qualifier, field_name) = expr.qualified_name();
808 let col = Column::new(qualifier, field_name);
809 result.push(Expr::Column(col));
810 }
811}
812
813#[cfg(test)]
814mod test {
815 use std::any::Any;
816 use std::iter;
817
818 use arrow::datatypes::{DataType, Field, Schema};
819 use datafusion_expr::logical_plan::{table_scan, JoinType};
820 use datafusion_expr::{
821 grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
822 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
823 SimpleAggregateUDF, Volatility,
824 };
825 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
826
827 use super::*;
828 use crate::assert_optimized_plan_eq_snapshot;
829 use crate::optimizer::OptimizerContext;
830 use crate::test::*;
831 use datafusion_expr::test::function_stub::{avg, sum};
832
833 macro_rules! assert_optimized_plan_equal {
834 (
835 $config:expr,
836 $plan:expr,
837 @ $expected:literal $(,)?
838 ) => {{
839 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
840 assert_optimized_plan_eq_snapshot!(
841 $config,
842 rules,
843 $plan,
844 @ $expected,
845 )
846 }};
847
848 (
849 $plan:expr,
850 @ $expected:literal $(,)?
851 ) => {{
852 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
853 let optimizer_ctx = OptimizerContext::new();
854 assert_optimized_plan_eq_snapshot!(
855 optimizer_ctx,
856 rules,
857 $plan,
858 @ $expected,
859 )
860 }};
861 }
862
863 #[test]
864 fn tpch_q1_simplified() -> Result<()> {
865 let table_scan = test_table_scan()?;
874
875 let plan = LogicalPlanBuilder::from(table_scan)
876 .aggregate(
877 iter::empty::<Expr>(),
878 vec![
879 sum(col("a") * (lit(1) - col("b"))),
880 sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
881 ],
882 )?
883 .build()?;
884
885 assert_optimized_plan_equal!(
886 plan,
887 @ r"
888 Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
889 Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
890 TableScan: test
891 "
892 )
893 }
894
895 #[test]
896 fn nested_aliases() -> Result<()> {
897 let table_scan = test_table_scan()?;
898
899 let plan = LogicalPlanBuilder::from(table_scan)
900 .project(vec![
901 (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
902 col("a") + col("b"),
903 ])?
904 .build()?;
905
906 assert_optimized_plan_equal!(
907 plan,
908 @ r"
909 Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
910 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
911 TableScan: test
912 "
913 )
914 }
915
916 #[test]
917 fn aggregate() -> Result<()> {
918 let table_scan = test_table_scan()?;
919
920 let return_type = DataType::UInt32;
921 let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
922 let udf_agg = |inner: Expr| {
923 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
924 Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
925 "my_agg",
926 Signature::exact(vec![DataType::UInt32], Volatility::Stable),
927 return_type.clone(),
928 Arc::clone(&accumulator),
929 vec![Field::new("value", DataType::UInt32, true).into()],
930 ))),
931 vec![inner],
932 false,
933 None,
934 vec![],
935 None,
936 ))
937 };
938
939 let plan = LogicalPlanBuilder::from(table_scan.clone())
941 .aggregate(
942 iter::empty::<Expr>(),
943 vec![
944 avg(col("a")).alias("col1"),
946 avg(col("a")).alias("col2"),
947 avg(col("b")).alias("col3"),
949 avg(col("c")),
950 udf_agg(col("a")).alias("col4"),
952 udf_agg(col("a")).alias("col5"),
953 udf_agg(col("b")).alias("col6"),
955 udf_agg(col("c")),
956 ],
957 )?
958 .build()?;
959
960 assert_optimized_plan_equal!(
961 plan,
962 @ r"
963 Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
964 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
965 TableScan: test
966 "
967 )?;
968
969 let plan = LogicalPlanBuilder::from(table_scan.clone())
971 .aggregate(
972 iter::empty::<Expr>(),
973 vec![
974 lit(1) + avg(col("a")),
975 lit(1) - avg(col("a")),
976 lit(1) + udf_agg(col("a")),
977 lit(1) - udf_agg(col("a")),
978 ],
979 )?
980 .build()?;
981
982 assert_optimized_plan_equal!(
983 plan,
984 @ r"
985 Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
986 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
987 TableScan: test
988 "
989 )?;
990
991 let plan = LogicalPlanBuilder::from(table_scan.clone())
993 .aggregate(
994 iter::empty::<Expr>(),
995 vec![
996 avg(lit(1u32) + col("a")).alias("col1"),
997 udf_agg(lit(1u32) + col("a")).alias("col2"),
998 ],
999 )?
1000 .build()?;
1001
1002 assert_optimized_plan_equal!(
1003 plan,
1004 @ r"
1005 Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1006 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1007 TableScan: test
1008 "
1009 )?;
1010
1011 let plan = LogicalPlanBuilder::from(table_scan.clone())
1013 .aggregate(
1014 vec![lit(1u32) + col("a")],
1015 vec![
1016 avg(lit(1u32) + col("a")).alias("col1"),
1017 udf_agg(lit(1u32) + col("a")).alias("col2"),
1018 ],
1019 )?
1020 .build()?;
1021
1022 assert_optimized_plan_equal!(
1023 plan,
1024 @ r"
1025 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1026 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1027 TableScan: test
1028 "
1029 )?;
1030
1031 let plan = LogicalPlanBuilder::from(table_scan)
1033 .aggregate(
1034 vec![lit(1u32) + col("a")],
1035 vec![
1036 (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1037 (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1038 avg(lit(1u32) + col("a")),
1039 (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1040 (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1041 udf_agg(lit(1u32) + col("a")),
1042 ],
1043 )?
1044 .build()?;
1045
1046 assert_optimized_plan_equal!(
1047 plan,
1048 @ r"
1049 Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1050 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1051 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1052 TableScan: test
1053 "
1054 )
1055 }
1056
1057 #[test]
1058 fn aggregate_with_relations_and_dots() -> Result<()> {
1059 let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1060 let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1061
1062 let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1063
1064 let plan = LogicalPlanBuilder::from(table_scan)
1065 .aggregate(
1066 vec![col_a.clone()],
1067 vec![
1068 (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1069 avg(lit(1u32) + col_a),
1070 ],
1071 )?
1072 .build()?;
1073
1074 assert_optimized_plan_equal!(
1075 plan,
1076 @ r"
1077 Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1078 Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1079 Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1080 TableScan: table.test
1081 "
1082 )
1083 }
1084
1085 #[test]
1086 fn subexpr_in_same_order() -> Result<()> {
1087 let table_scan = test_table_scan()?;
1088
1089 let plan = LogicalPlanBuilder::from(table_scan)
1090 .project(vec![
1091 (lit(1) + col("a")).alias("first"),
1092 (lit(1) + col("a")).alias("second"),
1093 ])?
1094 .build()?;
1095
1096 assert_optimized_plan_equal!(
1097 plan,
1098 @ r"
1099 Projection: __common_expr_1 AS first, __common_expr_1 AS second
1100 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1101 TableScan: test
1102 "
1103 )
1104 }
1105
1106 #[test]
1107 fn subexpr_in_different_order() -> Result<()> {
1108 let table_scan = test_table_scan()?;
1109
1110 let plan = LogicalPlanBuilder::from(table_scan)
1111 .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1112 .build()?;
1113
1114 assert_optimized_plan_equal!(
1115 plan,
1116 @ r"
1117 Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1118 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1119 TableScan: test
1120 "
1121 )
1122 }
1123
1124 #[test]
1125 fn cross_plans_subexpr() -> Result<()> {
1126 let table_scan = test_table_scan()?;
1127
1128 let plan = LogicalPlanBuilder::from(table_scan)
1129 .project(vec![lit(1) + col("a"), col("a")])?
1130 .project(vec![lit(1) + col("a")])?
1131 .build()?;
1132
1133 assert_optimized_plan_equal!(
1134 plan,
1135 @ r"
1136 Projection: Int32(1) + test.a
1137 Projection: Int32(1) + test.a, test.a
1138 TableScan: test
1139 "
1140 )
1141 }
1142
1143 #[test]
1144 fn redundant_project_fields() {
1145 let table_scan = test_table_scan().unwrap();
1146 let c_plus_a = col("c") + col("a");
1147 let b_plus_a = col("b") + col("a");
1148 let common_exprs_1 = vec![
1149 (c_plus_a, format!("{CSE_PREFIX}_1")),
1150 (b_plus_a, format!("{CSE_PREFIX}_2")),
1151 ];
1152 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1153 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1154 let common_exprs_2 = vec![
1155 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1156 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1157 ];
1158 let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1159 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1160
1161 let mut field_set = BTreeSet::new();
1162 for name in project_2.schema().field_names() {
1163 assert!(field_set.insert(name));
1164 }
1165 }
1166
1167 #[test]
1168 fn redundant_project_fields_join_input() {
1169 let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1170 let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1171 let join = LogicalPlanBuilder::from(table_scan_1)
1172 .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1173 .unwrap()
1174 .build()
1175 .unwrap();
1176 let c_plus_a = col("test1.c") + col("test1.a");
1177 let b_plus_a = col("test1.b") + col("test1.a");
1178 let common_exprs_1 = vec![
1179 (c_plus_a, format!("{CSE_PREFIX}_1")),
1180 (b_plus_a, format!("{CSE_PREFIX}_2")),
1181 ];
1182 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1183 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1184 let common_exprs_2 = vec![
1185 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1186 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1187 ];
1188 let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1189 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1190
1191 let mut field_set = BTreeSet::new();
1192 for name in project_2.schema().field_names() {
1193 assert!(field_set.insert(name));
1194 }
1195 }
1196
1197 #[test]
1198 fn eliminated_subexpr_datatype() {
1199 use datafusion_expr::cast;
1200
1201 let schema = Schema::new(vec![
1202 Field::new("a", DataType::UInt64, false),
1203 Field::new("b", DataType::UInt64, false),
1204 Field::new("c", DataType::UInt64, false),
1205 ]);
1206
1207 let plan = table_scan(Some("table"), &schema, None)
1208 .unwrap()
1209 .filter(
1210 cast(col("a"), DataType::Int64)
1211 .lt(lit(1_i64))
1212 .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1213 )
1214 .unwrap()
1215 .build()
1216 .unwrap();
1217 let rule = CommonSubexprEliminate::new();
1218 let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1219 assert!(optimized_plan.transformed);
1220 let optimized_plan = optimized_plan.data;
1221
1222 let schema = optimized_plan.schema();
1223 let fields_with_datatypes: Vec<_> = schema
1224 .fields()
1225 .iter()
1226 .map(|field| (field.name(), field.data_type()))
1227 .collect();
1228 let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1229 let expected = r#"[
1230 (
1231 "a",
1232 UInt64,
1233 ),
1234 (
1235 "b",
1236 UInt64,
1237 ),
1238 (
1239 "c",
1240 UInt64,
1241 ),
1242]"#;
1243 assert_eq!(expected, formatted_fields_with_datatype);
1244 }
1245
1246 #[test]
1247 fn filter_schema_changed() -> Result<()> {
1248 let table_scan = test_table_scan()?;
1249
1250 let plan = LogicalPlanBuilder::from(table_scan)
1251 .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1252 .build()?;
1253
1254 assert_optimized_plan_equal!(
1255 plan,
1256 @ r"
1257 Projection: test.a, test.b, test.c
1258 Filter: __common_expr_1 - Int32(10) > __common_expr_1
1259 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1260 TableScan: test
1261 "
1262 )
1263 }
1264
1265 #[test]
1266 fn test_extract_expressions_from_grouping_set() -> Result<()> {
1267 let mut result = Vec::with_capacity(3);
1268 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1269 extract_expressions(&grouping, &mut result);
1270
1271 assert!(result.len() == 3);
1272 Ok(())
1273 }
1274
1275 #[test]
1276 fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1277 let mut result = Vec::with_capacity(2);
1278 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1279 extract_expressions(&grouping, &mut result);
1280 assert!(result.len() == 2);
1281 Ok(())
1282 }
1283
1284 #[test]
1285 fn test_alias_collision() -> Result<()> {
1286 let table_scan = test_table_scan()?;
1287
1288 let config = OptimizerContext::new();
1289 let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1290 let plan = LogicalPlanBuilder::from(table_scan.clone())
1291 .project(vec![
1292 (col("a") + col("b")).alias(common_expr_1.clone()),
1293 col("c"),
1294 ])?
1295 .project(vec![
1296 col(common_expr_1.clone()).alias("c1"),
1297 col(common_expr_1).alias("c2"),
1298 (col("c") + lit(2)).alias("c3"),
1299 (col("c") + lit(2)).alias("c4"),
1300 ])?
1301 .build()?;
1302
1303 assert_optimized_plan_equal!(
1304 config,
1305 plan,
1306 @ r"
1307 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1308 Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1309 Projection: test.a + test.b AS __common_expr_1, test.c
1310 TableScan: test
1311 "
1312 )?;
1313
1314 let config = OptimizerContext::new();
1315 let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1316 let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1317 let plan = LogicalPlanBuilder::from(table_scan)
1318 .project(vec![
1319 (col("a") + col("b")).alias(common_expr_2.clone()),
1320 col("c"),
1321 ])?
1322 .project(vec![
1323 col(common_expr_2.clone()).alias("c1"),
1324 col(common_expr_2).alias("c2"),
1325 (col("c") + lit(2)).alias("c3"),
1326 (col("c") + lit(2)).alias("c4"),
1327 ])?
1328 .build()?;
1329
1330 assert_optimized_plan_equal!(
1331 config,
1332 plan,
1333 @ r"
1334 Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1335 Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1336 Projection: test.a + test.b AS __common_expr_2, test.c
1337 TableScan: test
1338 "
1339 )?;
1340
1341 Ok(())
1342 }
1343
1344 #[test]
1345 fn test_extract_expressions_from_col() -> Result<()> {
1346 let mut result = Vec::with_capacity(1);
1347 extract_expressions(&col("a"), &mut result);
1348 assert!(result.len() == 1);
1349 Ok(())
1350 }
1351
1352 #[test]
1353 fn test_short_circuits() -> Result<()> {
1354 let table_scan = test_table_scan()?;
1355
1356 let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1357 let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1358 let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1359 let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1360 let plan = LogicalPlanBuilder::from(table_scan)
1361 .project(vec![
1362 extracted_short_circuit.clone().alias("c1"),
1363 extracted_short_circuit.alias("c2"),
1364 extracted_short_circuit_leg_1
1365 .clone()
1366 .or(not_extracted_short_circuit_leg_2.clone())
1367 .alias("c3"),
1368 extracted_short_circuit_leg_1
1369 .and(not_extracted_short_circuit_leg_2)
1370 .alias("c4"),
1371 extracted_short_circuit_leg_3
1372 .clone()
1373 .or(extracted_short_circuit_leg_3)
1374 .alias("c5"),
1375 ])?
1376 .build()?;
1377
1378 assert_optimized_plan_equal!(
1379 plan,
1380 @ r"
1381 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1382 Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1383 TableScan: test
1384 "
1385 )
1386 }
1387
1388 #[test]
1389 fn test_volatile() -> Result<()> {
1390 let table_scan = test_table_scan()?;
1391
1392 let extracted_child = col("a") + col("b");
1393 let rand = rand_func().call(vec![]);
1394 let not_extracted_volatile = extracted_child + rand;
1395 let plan = LogicalPlanBuilder::from(table_scan)
1396 .project(vec![
1397 not_extracted_volatile.clone().alias("c1"),
1398 not_extracted_volatile.alias("c2"),
1399 ])?
1400 .build()?;
1401
1402 assert_optimized_plan_equal!(
1403 plan,
1404 @ r"
1405 Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1406 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1407 TableScan: test
1408 "
1409 )
1410 }
1411
1412 #[test]
1413 fn test_volatile_short_circuits() -> Result<()> {
1414 let table_scan = test_table_scan()?;
1415
1416 let rand = rand_func().call(vec![]);
1417 let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1418 let not_extracted_volatile_short_circuit_1 =
1419 extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1420 let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1421 let not_extracted_volatile_short_circuit_2 =
1422 rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1423 let plan = LogicalPlanBuilder::from(table_scan)
1424 .project(vec![
1425 not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1426 not_extracted_volatile_short_circuit_1.alias("c2"),
1427 not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1428 not_extracted_volatile_short_circuit_2.alias("c4"),
1429 ])?
1430 .build()?;
1431
1432 assert_optimized_plan_equal!(
1433 plan,
1434 @ r"
1435 Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1436 Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1437 TableScan: test
1438 "
1439 )
1440 }
1441
1442 #[test]
1443 fn test_non_top_level_common_expression() -> Result<()> {
1444 let table_scan = test_table_scan()?;
1445
1446 let common_expr = col("a") + col("b");
1447 let plan = LogicalPlanBuilder::from(table_scan)
1448 .project(vec![
1449 common_expr.clone().alias("c1"),
1450 common_expr.alias("c2"),
1451 ])?
1452 .project(vec![col("c1"), col("c2")])?
1453 .build()?;
1454
1455 assert_optimized_plan_equal!(
1456 plan,
1457 @ r"
1458 Projection: c1, c2
1459 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1460 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1461 TableScan: test
1462 "
1463 )
1464 }
1465
1466 #[test]
1467 fn test_nested_common_expression() -> Result<()> {
1468 let table_scan = test_table_scan()?;
1469
1470 let nested_common_expr = col("a") + col("b");
1471 let common_expr = nested_common_expr.clone() * nested_common_expr;
1472 let plan = LogicalPlanBuilder::from(table_scan)
1473 .project(vec![
1474 common_expr.clone().alias("c1"),
1475 common_expr.alias("c2"),
1476 ])?
1477 .build()?;
1478
1479 assert_optimized_plan_equal!(
1480 plan,
1481 @ r"
1482 Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1483 Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1484 Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1485 TableScan: test
1486 "
1487 )
1488 }
1489
1490 #[test]
1491 fn test_normalize_add_expression() -> Result<()> {
1492 let table_scan = test_table_scan()?;
1494 let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1495 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1496
1497 assert_optimized_plan_equal!(
1498 plan,
1499 @ r"
1500 Projection: test.a, test.b, test.c
1501 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1502 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1503 TableScan: test
1504 "
1505 )
1506 }
1507
1508 #[test]
1509 fn test_normalize_multi_expression() -> Result<()> {
1510 let table_scan = test_table_scan()?;
1512 let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1513 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1514
1515 assert_optimized_plan_equal!(
1516 plan,
1517 @ r"
1518 Projection: test.a, test.b, test.c
1519 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1520 Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1521 TableScan: test
1522 "
1523 )
1524 }
1525
1526 #[test]
1527 fn test_normalize_bitset_and_expression() -> Result<()> {
1528 let table_scan = test_table_scan()?;
1530 let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1531 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1532
1533 assert_optimized_plan_equal!(
1534 plan,
1535 @ r"
1536 Projection: test.a, test.b, test.c
1537 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1538 Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1539 TableScan: test
1540 "
1541 )
1542 }
1543
1544 #[test]
1545 fn test_normalize_bitset_or_expression() -> Result<()> {
1546 let table_scan = test_table_scan()?;
1548 let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1549 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1550
1551 assert_optimized_plan_equal!(
1552 plan,
1553 @ r"
1554 Projection: test.a, test.b, test.c
1555 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1556 Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1557 TableScan: test
1558 "
1559 )
1560 }
1561
1562 #[test]
1563 fn test_normalize_bitset_xor_expression() -> Result<()> {
1564 let table_scan = test_table_scan()?;
1566 let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1567 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1568
1569 assert_optimized_plan_equal!(
1570 plan,
1571 @ r"
1572 Projection: test.a, test.b, test.c
1573 Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1574 Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1575 TableScan: test
1576 "
1577 )
1578 }
1579
1580 #[test]
1581 fn test_normalize_eq_expression() -> Result<()> {
1582 let table_scan = test_table_scan()?;
1584 let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1585 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1586
1587 assert_optimized_plan_equal!(
1588 plan,
1589 @ r"
1590 Projection: test.a, test.b, test.c
1591 Filter: __common_expr_1 AND __common_expr_1
1592 Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1593 TableScan: test
1594 "
1595 )
1596 }
1597
1598 #[test]
1599 fn test_normalize_ne_expression() -> Result<()> {
1600 let table_scan = test_table_scan()?;
1602 let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1603 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1604
1605 assert_optimized_plan_equal!(
1606 plan,
1607 @ r"
1608 Projection: test.a, test.b, test.c
1609 Filter: __common_expr_1 AND __common_expr_1
1610 Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1611 TableScan: test
1612 "
1613 )
1614 }
1615
1616 #[test]
1617 fn test_normalize_complex_expression() -> Result<()> {
1618 let table_scan = test_table_scan()?;
1620 let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1621 .eq(lit(30));
1622 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1623
1624 assert_optimized_plan_equal!(
1625 plan,
1626 @ r"
1627 Projection: test.a, test.b, test.c
1628 Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1629 Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1630 TableScan: test
1631 "
1632 )?;
1633
1634 let table_scan = test_table_scan()?;
1636 let expr = (((col("a") + col("b") / col("c")) * col("c"))
1637 / (col("c") * (col("b") / col("c") + col("a")))
1638 + col("a"))
1639 .eq(lit(30));
1640 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1641
1642 assert_optimized_plan_equal!(
1643 plan,
1644 @ r"
1645 Projection: test.a, test.b, test.c
1646 Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1647 Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1648 TableScan: test
1649 "
1650 )?;
1651
1652 let table_scan = test_table_scan()?;
1654 let expr = ((col("b") / (col("a") + col("c")))
1655 * (col("b") / (col("c") + col("a"))))
1656 .eq(lit(30));
1657 let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1658 assert_optimized_plan_equal!(
1659 plan,
1660 @ r"
1661 Projection: test.a, test.b, test.c
1662 Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1663 Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1664 TableScan: test
1665 "
1666 )?;
1667
1668 Ok(())
1669 }
1670
1671 #[derive(Debug, PartialEq, Eq, Hash)]
1672 pub struct TestUdf {
1673 signature: Signature,
1674 }
1675
1676 impl TestUdf {
1677 pub fn new() -> Self {
1678 Self {
1679 signature: Signature::numeric(1, Volatility::Immutable),
1680 }
1681 }
1682 }
1683
1684 impl ScalarUDFImpl for TestUdf {
1685 fn as_any(&self) -> &dyn Any {
1686 self
1687 }
1688 fn name(&self) -> &str {
1689 "my_udf"
1690 }
1691
1692 fn signature(&self) -> &Signature {
1693 &self.signature
1694 }
1695
1696 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1697 Ok(DataType::Int32)
1698 }
1699
1700 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1701 panic!("not implemented")
1702 }
1703 }
1704
1705 #[test]
1706 fn test_normalize_inner_binary_expression() -> Result<()> {
1707 let table_scan = test_table_scan()?;
1709 let expr1 = not(col("a").eq(col("b")));
1710 let expr2 = not(col("b").eq(col("a")));
1711 let plan = LogicalPlanBuilder::from(table_scan)
1712 .project(vec![expr1, expr2])?
1713 .build()?;
1714 assert_optimized_plan_equal!(
1715 plan,
1716 @ r"
1717 Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1718 Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1719 TableScan: test
1720 "
1721 )?;
1722
1723 let table_scan = test_table_scan()?;
1725 let expr1 = is_null(col("a").eq(col("b")));
1726 let expr2 = is_null(col("b").eq(col("a")));
1727 let plan = LogicalPlanBuilder::from(table_scan)
1728 .project(vec![expr1, expr2])?
1729 .build()?;
1730 assert_optimized_plan_equal!(
1731 plan,
1732 @ r"
1733 Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1734 Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1735 TableScan: test
1736 "
1737 )?;
1738
1739 let table_scan = test_table_scan()?;
1741 let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1742 let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1743 let plan = LogicalPlanBuilder::from(table_scan)
1744 .project(vec![expr1, expr2])?
1745 .build()?;
1746 assert_optimized_plan_equal!(
1747 plan,
1748 @ r"
1749 Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1750 Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1751 TableScan: test
1752 "
1753 )?;
1754
1755 let table_scan = test_table_scan()?;
1757 let expr1 = col("c").between(col("a") + col("b"), lit(10));
1758 let expr2 = col("c").between(col("b") + col("a"), lit(10));
1759 let plan = LogicalPlanBuilder::from(table_scan)
1760 .project(vec![expr1, expr2])?
1761 .build()?;
1762 assert_optimized_plan_equal!(
1763 plan,
1764 @ r"
1765 Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1766 Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1767 TableScan: test
1768 "
1769 )?;
1770
1771 let udf = ScalarUDF::from(TestUdf::new());
1773 let table_scan = test_table_scan()?;
1774 let expr1 = udf.call(vec![col("a") + col("b")]);
1775 let expr2 = udf.call(vec![col("b") + col("a")]);
1776 let plan = LogicalPlanBuilder::from(table_scan)
1777 .project(vec![expr1, expr2])?
1778 .build()?;
1779 assert_optimized_plan_equal!(
1780 plan,
1781 @ r"
1782 Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1783 Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1784 TableScan: test
1785 "
1786 )
1787 }
1788
1789 fn rand_func() -> ScalarUDF {
1795 ScalarUDF::new_from_impl(RandomStub::new())
1796 }
1797
1798 #[derive(Debug, PartialEq, Eq, Hash)]
1799 struct RandomStub {
1800 signature: Signature,
1801 }
1802
1803 impl RandomStub {
1804 fn new() -> Self {
1805 Self {
1806 signature: Signature::exact(vec![], Volatility::Volatile),
1807 }
1808 }
1809 }
1810 impl ScalarUDFImpl for RandomStub {
1811 fn as_any(&self) -> &dyn Any {
1812 self
1813 }
1814
1815 fn name(&self) -> &str {
1816 "random"
1817 }
1818
1819 fn signature(&self) -> &Signature {
1820 &self.signature
1821 }
1822
1823 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1824 Ok(DataType::Float64)
1825 }
1826
1827 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1828 panic!("dummy - not implemented")
1829 }
1830 }
1831}