1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
/// Prefix of the aliases generated for extracted common sub-expressions
/// (the alias generator appends a counter, e.g. `__common_expr_1`).
const CSE_PREFIX: &str = "__common_expr";
40
/// Logical optimizer rule that eliminates common sub-expressions.
///
/// Repeated expressions in `Projection`, `Sort`, `Filter`, `Window` and
/// `Aggregate` nodes are computed once, under a generated alias, in a
/// `Projection` inserted below the node. The node then refers to that alias
/// instead of re-evaluating the expression.
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71 pub fn new() -> Self {
72 Self {}
73 }
74
75 fn try_optimize_proj(
76 &self,
77 projection: Projection,
78 config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 let Projection {
81 expr,
82 input,
83 schema,
84 ..
85 } = projection;
86 let input = Arc::unwrap_or_clone(input);
87 self.try_unary_plan(expr, input, config)?
88 .map_data(|(new_expr, new_input)| {
89 Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90 .map(LogicalPlan::Projection)
91 })
92 }
93
94 fn try_optimize_sort(
95 &self,
96 sort: Sort,
97 config: &dyn OptimizerConfig,
98 ) -> Result<Transformed<LogicalPlan>> {
99 let Sort { expr, input, fetch } = sort;
100 let input = Arc::unwrap_or_clone(input);
101 let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102 .into_iter()
103 .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104 .unzip();
105 let new_sort = self
106 .try_unary_plan(sort_expressions, input, config)?
107 .update_data(|(new_expr, new_input)| {
108 LogicalPlan::Sort(Sort {
109 expr: new_expr
110 .into_iter()
111 .zip(sort_params)
112 .map(|(expr, (asc, nulls_first))| SortExpr {
113 expr,
114 asc,
115 nulls_first,
116 })
117 .collect(),
118 input: Arc::new(new_input),
119 fetch,
120 })
121 });
122 Ok(new_sort)
123 }
124
125 fn try_optimize_filter(
126 &self,
127 filter: Filter,
128 config: &dyn OptimizerConfig,
129 ) -> Result<Transformed<LogicalPlan>> {
130 let Filter {
131 predicate, input, ..
132 } = filter;
133 let input = Arc::unwrap_or_clone(input);
134 let expr = vec![predicate];
135 self.try_unary_plan(expr, input, config)?
136 .map_data(|(mut new_expr, new_input)| {
137 assert_eq!(new_expr.len(), 1); let new_predicate = new_expr.pop().unwrap();
139 Filter::try_new(new_predicate, Arc::new(new_input))
140 .map(LogicalPlan::Filter)
141 })
142 }
143
144 fn try_optimize_window(
145 &self,
146 window: Window,
147 config: &dyn OptimizerConfig,
148 ) -> Result<Transformed<LogicalPlan>> {
149 let (window_expr_list, window_schemas, input) =
152 get_consecutive_window_exprs(window);
153
154 match CSE::new(ExprCSEController::new(
157 config.alias_generator().as_ref(),
158 ExprMask::Normal,
159 ))
160 .extract_common_nodes(window_expr_list)?
161 {
162 FoundCommonNodes::Yes {
166 common_nodes: common_exprs,
167 new_nodes_list: new_exprs_list,
168 original_nodes_list: original_exprs_list,
169 } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170 Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171 }),
172 FoundCommonNodes::No {
173 original_nodes_list: original_exprs_list,
174 } => Ok(Transformed::no((original_exprs_list, input, None))),
175 }?
176 .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179 self.rewrite(new_input, config)?.map_data(|new_input| {
180 Ok((new_window_expr_list, new_input, window_expr_list))
181 })
182 })?
183 .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185 if let Some(window_expr_list) = window_expr_list {
194 let name_preserver = NamePreserver::new_for_projection();
195 let saved_names = window_expr_list
196 .iter()
197 .map(|exprs| {
198 exprs
199 .iter()
200 .map(|expr| name_preserver.save(expr))
201 .collect::<Vec<_>>()
202 })
203 .collect::<Vec<_>>();
204 new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205 new_input,
206 |plan, (new_window_expr, saved_names)| {
207 let new_window_expr = new_window_expr
208 .into_iter()
209 .zip(saved_names)
210 .map(|(new_window_expr, saved_name)| {
211 saved_name.restore(new_window_expr)
212 })
213 .collect::<Vec<_>>();
214 Window::try_new(new_window_expr, Arc::new(plan))
215 .map(LogicalPlan::Window)
216 },
217 )
218 } else {
219 new_window_expr_list
220 .into_iter()
221 .zip(window_schemas)
222 .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223 Window::try_new_with_schema(
224 new_window_expr,
225 Arc::new(plan),
226 schema,
227 )
228 .map(LogicalPlan::Window)
229 })
230 }
231 })
232 }
233
234 fn try_optimize_aggregate(
235 &self,
236 aggregate: Aggregate,
237 config: &dyn OptimizerConfig,
238 ) -> Result<Transformed<LogicalPlan>> {
239 let Aggregate {
240 group_expr,
241 aggr_expr,
242 input,
243 schema,
244 ..
245 } = aggregate;
246 let input = Arc::unwrap_or_clone(input);
247 match CSE::new(ExprCSEController::new(
249 config.alias_generator().as_ref(),
250 ExprMask::Normal,
251 ))
252 .extract_common_nodes(vec![group_expr, aggr_expr])?
253 {
254 FoundCommonNodes::Yes {
258 common_nodes: common_exprs,
259 new_nodes_list: mut new_exprs_list,
260 original_nodes_list: mut original_exprs_list,
261 } => {
262 let new_aggr_expr = new_exprs_list.pop().unwrap();
263 let new_group_expr = new_exprs_list.pop().unwrap();
264
265 build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266 let aggr_expr = original_exprs_list.pop().unwrap();
267 Transformed::yes((
268 new_aggr_expr,
269 new_group_expr,
270 new_input,
271 Some(aggr_expr),
272 ))
273 })
274 }
275
276 FoundCommonNodes::No {
277 original_nodes_list: mut original_exprs_list,
278 } => {
279 let new_aggr_expr = original_exprs_list.pop().unwrap();
280 let new_group_expr = original_exprs_list.pop().unwrap();
281
282 Ok(Transformed::no((
283 new_aggr_expr,
284 new_group_expr,
285 input,
286 None,
287 )))
288 }
289 }?
290 .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293 self.rewrite(new_input, config)?.map_data(|new_input| {
294 Ok((
295 new_aggr_expr,
296 new_group_expr,
297 aggr_expr,
298 Arc::new(new_input),
299 ))
300 })
301 })?
302 .transform_data(
304 |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305 match CSE::new(ExprCSEController::new(
307 config.alias_generator().as_ref(),
308 ExprMask::NormalAndAggregates,
309 ))
310 .extract_common_nodes(vec![new_aggr_expr])?
311 {
312 FoundCommonNodes::Yes {
313 common_nodes: common_exprs,
314 new_nodes_list: mut new_exprs_list,
315 original_nodes_list: mut original_exprs_list,
316 } => {
317 let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318 let new_aggr_expr = original_exprs_list.pop().unwrap();
319
320 let mut agg_exprs = common_exprs
321 .into_iter()
322 .map(|(expr, expr_alias)| expr.alias(expr_alias))
323 .collect::<Vec<_>>();
324
325 let mut proj_exprs = vec![];
326 for expr in &new_group_expr {
327 extract_expressions(expr, &mut proj_exprs)
328 }
329 for (expr_rewritten, expr_orig) in
330 rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
331 {
332 if expr_rewritten == expr_orig {
333 if let Expr::Alias(Alias { expr, name, .. }) =
334 expr_rewritten
335 {
336 agg_exprs.push(expr.alias(&name));
337 proj_exprs
338 .push(Expr::Column(Column::from_name(name)));
339 } else {
340 let expr_alias =
341 config.alias_generator().next(CSE_PREFIX);
342 let (qualifier, field_name) =
343 expr_rewritten.qualified_name();
344 let out_name =
345 qualified_name(qualifier.as_ref(), &field_name);
346
347 agg_exprs.push(expr_rewritten.alias(&expr_alias));
348 proj_exprs.push(
349 Expr::Column(Column::from_name(expr_alias))
350 .alias(out_name),
351 );
352 }
353 } else {
354 proj_exprs.push(expr_rewritten);
355 }
356 }
357
358 let agg = LogicalPlan::Aggregate(Aggregate::try_new(
359 new_input,
360 new_group_expr,
361 agg_exprs,
362 )?);
363 Projection::try_new(proj_exprs, Arc::new(agg))
364 .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
365 }
366
367 FoundCommonNodes::No {
370 original_nodes_list: mut original_exprs_list,
371 } => {
372 let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
373
374 if let Some(aggr_expr) = aggr_expr {
385 let name_preserver = NamePreserver::new_for_projection();
386 let saved_names = aggr_expr
387 .iter()
388 .map(|expr| name_preserver.save(expr))
389 .collect::<Vec<_>>();
390 let new_aggr_expr = rewritten_aggr_expr
391 .into_iter()
392 .zip(saved_names)
393 .map(|(new_expr, saved_name)| {
394 saved_name.restore(new_expr)
395 })
396 .collect::<Vec<Expr>>();
397
398 Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
401 .map(LogicalPlan::Aggregate)
402 .map(Transformed::no)
403 } else {
404 Aggregate::try_new_with_schema(
405 new_input,
406 new_group_expr,
407 rewritten_aggr_expr,
408 schema,
409 )
410 .map(LogicalPlan::Aggregate)
411 .map(Transformed::no)
412 }
413 }
414 }
415 },
416 )
417 }
418
419 fn try_unary_plan(
434 &self,
435 exprs: Vec<Expr>,
436 input: LogicalPlan,
437 config: &dyn OptimizerConfig,
438 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
439 match CSE::new(ExprCSEController::new(
441 config.alias_generator().as_ref(),
442 ExprMask::Normal,
443 ))
444 .extract_common_nodes(vec![exprs])?
445 {
446 FoundCommonNodes::Yes {
447 common_nodes: common_exprs,
448 new_nodes_list: mut new_exprs_list,
449 original_nodes_list: _,
450 } => {
451 let new_exprs = new_exprs_list.pop().unwrap();
452 build_common_expr_project_plan(input, common_exprs)
453 .map(|new_input| Transformed::yes((new_exprs, new_input)))
454 }
455 FoundCommonNodes::No {
456 original_nodes_list: mut original_exprs_list,
457 } => {
458 let new_exprs = original_exprs_list.pop().unwrap();
459 Ok(Transformed::no((new_exprs, input)))
460 }
461 }?
462 .transform_data(|(new_exprs, new_input)| {
465 self.rewrite(new_input, config)?
466 .map_data(|new_input| Ok((new_exprs, new_input)))
467 })
468 }
469}
470
471fn get_consecutive_window_exprs(
503 window: Window,
504) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
505 let mut window_expr_list = vec![];
506 let mut window_schemas = vec![];
507 let mut plan = LogicalPlan::Window(window);
508 while let LogicalPlan::Window(Window {
509 input,
510 window_expr,
511 schema,
512 }) = plan
513 {
514 window_expr_list.push(window_expr);
515 window_schemas.push(schema);
516
517 plan = Arc::unwrap_or_clone(input);
518 }
519 (window_expr_list, window_schemas, plan)
520}
521
522impl OptimizerRule for CommonSubexprEliminate {
523 fn supports_rewrite(&self) -> bool {
524 true
525 }
526
527 fn apply_order(&self) -> Option<ApplyOrder> {
528 None
532 }
533
534 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
535 fn rewrite(
536 &self,
537 plan: LogicalPlan,
538 config: &dyn OptimizerConfig,
539 ) -> Result<Transformed<LogicalPlan>> {
540 let original_schema = Arc::clone(plan.schema());
541
542 let optimized_plan = match plan {
543 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
544 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
545 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
546 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
547 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
548 LogicalPlan::Join(_)
549 | LogicalPlan::Repartition(_)
550 | LogicalPlan::Union(_)
551 | LogicalPlan::TableScan(_)
552 | LogicalPlan::Values(_)
553 | LogicalPlan::EmptyRelation(_)
554 | LogicalPlan::Subquery(_)
555 | LogicalPlan::SubqueryAlias(_)
556 | LogicalPlan::Limit(_)
557 | LogicalPlan::Ddl(_)
558 | LogicalPlan::Explain(_)
559 | LogicalPlan::Analyze(_)
560 | LogicalPlan::Statement(_)
561 | LogicalPlan::DescribeTable(_)
562 | LogicalPlan::Distinct(_)
563 | LogicalPlan::Extension(_)
564 | LogicalPlan::Dml(_)
565 | LogicalPlan::Copy(_)
566 | LogicalPlan::Unnest(_)
567 | LogicalPlan::RecursiveQuery(_) => {
568 plan.map_children(|c| self.rewrite(c, config))?
571 }
572 };
573
574 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
576 {
577 optimized_plan.map_data(|optimized_plan| {
578 build_recover_project_plan(&original_schema, optimized_plan)
579 })
580 } else {
581 Ok(optimized_plan)
582 }
583 }
584
585 fn name(&self) -> &str {
586 "common_sub_expression_eliminate"
587 }
588}
589
/// Selects which expression kinds `ExprCSEController::is_ignored` excludes from
/// common sub-expression extraction.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores literals, columns, scalar variables, aliases, wildcards and
    /// aggregate function calls.
    Normal,

    /// Ignores the same kinds as `Normal` except aggregate function calls, so
    /// repeated aggregate calls can themselves be extracted.
    NormalAndAggregates,
}
606
/// `CSEController` implementation for logical `Expr` trees.
struct ExprCSEController<'a> {
    /// Produces the aliases (prefixed with `CSE_PREFIX`) given to extracted
    /// expressions.
    alias_generator: &'a AliasGenerator,
    /// Which expression kinds `is_ignored` skips.
    mask: ExprMask,

    /// `Alias` nesting depth while rewriting. A non-zero value means the node
    /// being rewritten already sits under an alias, so `rewrite` does not wrap
    /// it in another one.
    alias_counter: usize,
}
614
615impl<'a> ExprCSEController<'a> {
616 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
617 Self {
618 alias_generator,
619 mask,
620 alias_counter: 0,
621 }
622 }
623}
624
625impl CSEController for ExprCSEController<'_> {
626 type Node = Expr;
627
628 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
629 match node {
630 Expr::ScalarFunction(ScalarFunction { func, args })
634 if func.short_circuits() =>
635 {
636 Some((vec![], args.iter().collect()))
637 }
638
639 Expr::BinaryExpr(BinaryExpr {
642 left,
643 op: Operator::And | Operator::Or,
644 right,
645 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
646
647 Expr::Case(Case {
651 expr,
652 when_then_expr,
653 else_expr,
654 }) => Some((
655 expr.iter()
656 .map(|e| e.as_ref())
657 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
658 .collect(),
659 when_then_expr
660 .iter()
661 .take(1)
662 .map(|(_, then)| then.as_ref())
663 .chain(
664 when_then_expr
665 .iter()
666 .skip(1)
667 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
668 )
669 .chain(else_expr.iter().map(|e| e.as_ref()))
670 .collect(),
671 )),
672 _ => None,
673 }
674 }
675
676 fn is_valid(node: &Expr) -> bool {
677 !node.is_volatile_node()
678 }
679
680 fn is_ignored(&self, node: &Expr) -> bool {
681 #[expect(deprecated)]
683 let is_normal_minus_aggregates = matches!(
684 node,
685 Expr::Literal(..)
686 | Expr::Column(..)
687 | Expr::ScalarVariable(..)
688 | Expr::Alias(..)
689 | Expr::Wildcard { .. }
690 );
691
692 let is_aggr = matches!(node, Expr::AggregateFunction(..));
693
694 match self.mask {
695 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
696 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
697 }
698 }
699
700 fn generate_alias(&self) -> String {
701 self.alias_generator.next(CSE_PREFIX)
702 }
703
704 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
705 if self.alias_counter > 0 {
707 col(alias)
708 } else {
709 self.alias_counter += 1;
710 col(alias).alias(node.schema_name().to_string())
711 }
712 }
713
714 fn rewrite_f_down(&mut self, node: &Expr) {
715 if matches!(node, Expr::Alias(_)) {
716 self.alias_counter += 1;
717 }
718 }
719 fn rewrite_f_up(&mut self, node: &Expr) {
720 if matches!(node, Expr::Alias(_)) {
721 self.alias_counter -= 1
722 }
723 }
724}
725
726impl Default for CommonSubexprEliminate {
727 fn default() -> Self {
728 Self::new()
729 }
730}
731
732fn build_common_expr_project_plan(
743 input: LogicalPlan,
744 common_exprs: Vec<(Expr, String)>,
745) -> Result<LogicalPlan> {
746 let mut fields_set = BTreeSet::new();
747 let mut project_exprs = common_exprs
748 .into_iter()
749 .map(|(expr, expr_alias)| {
750 fields_set.insert(expr_alias.clone());
751 Ok(expr.alias(expr_alias))
752 })
753 .collect::<Result<Vec<_>>>()?;
754
755 for (qualifier, field) in input.schema().iter() {
756 if fields_set.insert(qualified_name(qualifier, field.name())) {
757 project_exprs.push(Expr::from((qualifier, field)));
758 }
759 }
760
761 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
762}
763
764fn build_recover_project_plan(
770 schema: &DFSchema,
771 input: LogicalPlan,
772) -> Result<LogicalPlan> {
773 let col_exprs = schema.iter().map(Expr::from).collect();
774 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
775}
776
777fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
778 if let Expr::GroupingSet(groupings) = expr {
779 for e in groupings.distinct_expr() {
780 let (qualifier, field_name) = e.qualified_name();
781 let col = Column::new(qualifier, field_name);
782 result.push(Expr::Column(col))
783 }
784 } else {
785 let (qualifier, field_name) = expr.qualified_name();
786 let col = Column::new(qualifier, field_name);
787 result.push(Expr::Column(col));
788 }
789}
790
791#[cfg(test)]
792mod test {
793 use std::any::Any;
794 use std::iter;
795
796 use arrow::datatypes::{DataType, Field, Schema};
797 use datafusion_expr::logical_plan::{table_scan, JoinType};
798 use datafusion_expr::{
799 grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
800 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
801 SimpleAggregateUDF, Volatility,
802 };
803 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
804
805 use super::*;
806 use crate::assert_optimized_plan_eq_snapshot;
807 use crate::optimizer::OptimizerContext;
808 use crate::test::*;
809 use datafusion_expr::test::function_stub::{avg, sum};
810
811 macro_rules! assert_optimized_plan_equal {
812 (
813 $config:expr,
814 $plan:expr,
815 @ $expected:literal $(,)?
816 ) => {{
817 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
818 assert_optimized_plan_eq_snapshot!(
819 $config,
820 rules,
821 $plan,
822 @ $expected,
823 )
824 }};
825
826 (
827 $plan:expr,
828 @ $expected:literal $(,)?
829 ) => {{
830 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
831 let optimizer_ctx = OptimizerContext::new();
832 assert_optimized_plan_eq_snapshot!(
833 optimizer_ctx,
834 rules,
835 $plan,
836 @ $expected,
837 )
838 }};
839 }
840
841 #[test]
842 fn tpch_q1_simplified() -> Result<()> {
843 let table_scan = test_table_scan()?;
852
853 let plan = LogicalPlanBuilder::from(table_scan)
854 .aggregate(
855 iter::empty::<Expr>(),
856 vec![
857 sum(col("a") * (lit(1) - col("b"))),
858 sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
859 ],
860 )?
861 .build()?;
862
863 assert_optimized_plan_equal!(
864 plan,
865 @ r"
866 Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
867 Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
868 TableScan: test
869 "
870 )
871 }
872
873 #[test]
874 fn nested_aliases() -> Result<()> {
875 let table_scan = test_table_scan()?;
876
877 let plan = LogicalPlanBuilder::from(table_scan)
878 .project(vec![
879 (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
880 col("a") + col("b"),
881 ])?
882 .build()?;
883
884 assert_optimized_plan_equal!(
885 plan,
886 @ r"
887 Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
888 Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
889 TableScan: test
890 "
891 )
892 }
893
894 #[test]
895 fn aggregate() -> Result<()> {
896 let table_scan = test_table_scan()?;
897
898 let return_type = DataType::UInt32;
899 let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
900 let udf_agg = |inner: Expr| {
901 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
902 Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
903 "my_agg",
904 Signature::exact(vec![DataType::UInt32], Volatility::Stable),
905 return_type.clone(),
906 Arc::clone(&accumulator),
907 vec![Field::new("value", DataType::UInt32, true).into()],
908 ))),
909 vec![inner],
910 false,
911 None,
912 None,
913 None,
914 ))
915 };
916
917 let plan = LogicalPlanBuilder::from(table_scan.clone())
919 .aggregate(
920 iter::empty::<Expr>(),
921 vec![
922 avg(col("a")).alias("col1"),
924 avg(col("a")).alias("col2"),
925 avg(col("b")).alias("col3"),
927 avg(col("c")),
928 udf_agg(col("a")).alias("col4"),
930 udf_agg(col("a")).alias("col5"),
931 udf_agg(col("b")).alias("col6"),
933 udf_agg(col("c")),
934 ],
935 )?
936 .build()?;
937
938 assert_optimized_plan_equal!(
939 plan,
940 @ r"
941 Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
942 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
943 TableScan: test
944 "
945 )?;
946
947 let plan = LogicalPlanBuilder::from(table_scan.clone())
949 .aggregate(
950 iter::empty::<Expr>(),
951 vec![
952 lit(1) + avg(col("a")),
953 lit(1) - avg(col("a")),
954 lit(1) + udf_agg(col("a")),
955 lit(1) - udf_agg(col("a")),
956 ],
957 )?
958 .build()?;
959
960 assert_optimized_plan_equal!(
961 plan,
962 @ r"
963 Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
964 Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
965 TableScan: test
966 "
967 )?;
968
969 let plan = LogicalPlanBuilder::from(table_scan.clone())
971 .aggregate(
972 iter::empty::<Expr>(),
973 vec![
974 avg(lit(1u32) + col("a")).alias("col1"),
975 udf_agg(lit(1u32) + col("a")).alias("col2"),
976 ],
977 )?
978 .build()?;
979
980 assert_optimized_plan_equal!(
981 plan,
982 @ r"
983 Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
984 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
985 TableScan: test
986 "
987 )?;
988
989 let plan = LogicalPlanBuilder::from(table_scan.clone())
991 .aggregate(
992 vec![lit(1u32) + col("a")],
993 vec![
994 avg(lit(1u32) + col("a")).alias("col1"),
995 udf_agg(lit(1u32) + col("a")).alias("col2"),
996 ],
997 )?
998 .build()?;
999
1000 assert_optimized_plan_equal!(
1001 plan,
1002 @ r"
1003 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1004 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1005 TableScan: test
1006 "
1007 )?;
1008
1009 let plan = LogicalPlanBuilder::from(table_scan)
1011 .aggregate(
1012 vec![lit(1u32) + col("a")],
1013 vec![
1014 (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1015 (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1016 avg(lit(1u32) + col("a")),
1017 (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1018 (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1019 udf_agg(lit(1u32) + col("a")),
1020 ],
1021 )?
1022 .build()?;
1023
1024 assert_optimized_plan_equal!(
1025 plan,
1026 @ r"
1027 Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1028 Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1029 Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1030 TableScan: test
1031 "
1032 )
1033 }
1034
1035 #[test]
1036 fn aggregate_with_relations_and_dots() -> Result<()> {
1037 let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1038 let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1039
1040 let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1041
1042 let plan = LogicalPlanBuilder::from(table_scan)
1043 .aggregate(
1044 vec![col_a.clone()],
1045 vec![
1046 (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1047 avg(lit(1u32) + col_a),
1048 ],
1049 )?
1050 .build()?;
1051
1052 assert_optimized_plan_equal!(
1053 plan,
1054 @ r"
1055 Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1056 Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1057 Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1058 TableScan: table.test
1059 "
1060 )
1061 }
1062
1063 #[test]
1064 fn subexpr_in_same_order() -> Result<()> {
1065 let table_scan = test_table_scan()?;
1066
1067 let plan = LogicalPlanBuilder::from(table_scan)
1068 .project(vec![
1069 (lit(1) + col("a")).alias("first"),
1070 (lit(1) + col("a")).alias("second"),
1071 ])?
1072 .build()?;
1073
1074 assert_optimized_plan_equal!(
1075 plan,
1076 @ r"
1077 Projection: __common_expr_1 AS first, __common_expr_1 AS second
1078 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1079 TableScan: test
1080 "
1081 )
1082 }
1083
1084 #[test]
1085 fn subexpr_in_different_order() -> Result<()> {
1086 let table_scan = test_table_scan()?;
1087
1088 let plan = LogicalPlanBuilder::from(table_scan)
1089 .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1090 .build()?;
1091
1092 assert_optimized_plan_equal!(
1093 plan,
1094 @ r"
1095 Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1096 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1097 TableScan: test
1098 "
1099 )
1100 }
1101
1102 #[test]
1103 fn cross_plans_subexpr() -> Result<()> {
1104 let table_scan = test_table_scan()?;
1105
1106 let plan = LogicalPlanBuilder::from(table_scan)
1107 .project(vec![lit(1) + col("a"), col("a")])?
1108 .project(vec![lit(1) + col("a")])?
1109 .build()?;
1110
1111 assert_optimized_plan_equal!(
1112 plan,
1113 @ r"
1114 Projection: Int32(1) + test.a
1115 Projection: Int32(1) + test.a, test.a
1116 TableScan: test
1117 "
1118 )
1119 }
1120
1121 #[test]
1122 fn redundant_project_fields() {
1123 let table_scan = test_table_scan().unwrap();
1124 let c_plus_a = col("c") + col("a");
1125 let b_plus_a = col("b") + col("a");
1126 let common_exprs_1 = vec![
1127 (c_plus_a, format!("{CSE_PREFIX}_1")),
1128 (b_plus_a, format!("{CSE_PREFIX}_2")),
1129 ];
1130 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1131 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1132 let common_exprs_2 = vec![
1133 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1134 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1135 ];
1136 let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1137 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1138
1139 let mut field_set = BTreeSet::new();
1140 for name in project_2.schema().field_names() {
1141 assert!(field_set.insert(name));
1142 }
1143 }
1144
1145 #[test]
1146 fn redundant_project_fields_join_input() {
1147 let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1148 let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1149 let join = LogicalPlanBuilder::from(table_scan_1)
1150 .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1151 .unwrap()
1152 .build()
1153 .unwrap();
1154 let c_plus_a = col("test1.c") + col("test1.a");
1155 let b_plus_a = col("test1.b") + col("test1.a");
1156 let common_exprs_1 = vec![
1157 (c_plus_a, format!("{CSE_PREFIX}_1")),
1158 (b_plus_a, format!("{CSE_PREFIX}_2")),
1159 ];
1160 let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1161 let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1162 let common_exprs_2 = vec![
1163 (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1164 (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1165 ];
1166 let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1167 let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1168
1169 let mut field_set = BTreeSet::new();
1170 for name in project_2.schema().field_names() {
1171 assert!(field_set.insert(name));
1172 }
1173 }
1174
1175 #[test]
1176 fn eliminated_subexpr_datatype() {
1177 use datafusion_expr::cast;
1178
1179 let schema = Schema::new(vec![
1180 Field::new("a", DataType::UInt64, false),
1181 Field::new("b", DataType::UInt64, false),
1182 Field::new("c", DataType::UInt64, false),
1183 ]);
1184
1185 let plan = table_scan(Some("table"), &schema, None)
1186 .unwrap()
1187 .filter(
1188 cast(col("a"), DataType::Int64)
1189 .lt(lit(1_i64))
1190 .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1191 )
1192 .unwrap()
1193 .build()
1194 .unwrap();
1195 let rule = CommonSubexprEliminate::new();
1196 let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1197 assert!(optimized_plan.transformed);
1198 let optimized_plan = optimized_plan.data;
1199
1200 let schema = optimized_plan.schema();
1201 let fields_with_datatypes: Vec<_> = schema
1202 .fields()
1203 .iter()
1204 .map(|field| (field.name(), field.data_type()))
1205 .collect();
1206 let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1207 let expected = r#"[
1208 (
1209 "a",
1210 UInt64,
1211 ),
1212 (
1213 "b",
1214 UInt64,
1215 ),
1216 (
1217 "c",
1218 UInt64,
1219 ),
1220]"#;
1221 assert_eq!(expected, formatted_fields_with_datatype);
1222 }
1223
1224 #[test]
1225 fn filter_schema_changed() -> Result<()> {
1226 let table_scan = test_table_scan()?;
1227
1228 let plan = LogicalPlanBuilder::from(table_scan)
1229 .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1230 .build()?;
1231
1232 assert_optimized_plan_equal!(
1233 plan,
1234 @ r"
1235 Projection: test.a, test.b, test.c
1236 Filter: __common_expr_1 - Int32(10) > __common_expr_1
1237 Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1238 TableScan: test
1239 "
1240 )
1241 }
1242
1243 #[test]
1244 fn test_extract_expressions_from_grouping_set() -> Result<()> {
1245 let mut result = Vec::with_capacity(3);
1246 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1247 extract_expressions(&grouping, &mut result);
1248
1249 assert!(result.len() == 3);
1250 Ok(())
1251 }
1252
1253 #[test]
1254 fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1255 let mut result = Vec::with_capacity(2);
1256 let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1257 extract_expressions(&grouping, &mut result);
1258 assert!(result.len() == 2);
1259 Ok(())
1260 }
1261
#[test]
fn test_alias_collision() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Builds `a + b AS <alias>, c`, then a projection that reads `<alias>`
    // twice and repeats `c + 2` twice, so CSE has something to extract on top
    // of a user alias that looks like a generated one.
    let build_plan = |scan: LogicalPlan, alias: String| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(scan)
            .project(vec![(col("a") + col("b")).alias(alias.clone()), col("c")])?
            .project(vec![
                col(alias.clone()).alias("c1"),
                col(alias).alias("c2"),
                (col("c") + lit(2)).alias("c3"),
                (col("c") + lit(2)).alias("c4"),
            ])?
            .build()
    };

    // The user alias consumes the first generated name, so the optimizer has
    // to continue with `__common_expr_2`.
    let config = OptimizerContext::new();
    let reserved_alias = config.alias_generator().next(CSE_PREFIX);
    let plan = build_plan(table_scan.clone(), reserved_alias)?;
    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
    Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
      Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
        Projection: test.a + test.b AS __common_expr_1, test.c
          TableScan: test
    "
    )?;

    // Here the generator is advanced twice and the user alias takes the
    // second name, so CSE continues with `__common_expr_3`.
    let config = OptimizerContext::new();
    let _skipped_alias = config.alias_generator().next(CSE_PREFIX);
    let reserved_alias = config.alias_generator().next(CSE_PREFIX);
    let plan = build_plan(table_scan, reserved_alias)?;
    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
    Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
      Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
        Projection: test.a + test.b AS __common_expr_2, test.c
          TableScan: test
    "
    )?;

    Ok(())
}
1321
#[test]
fn test_extract_expressions_from_col() -> Result<()> {
    // A plain column (not a grouping set) yields exactly one entry.
    let mut result = Vec::with_capacity(1);
    extract_expressions(&col("a"), &mut result);
    assert_eq!(result.len(), 1);
    Ok(())
}
1329
#[test]
fn test_short_circuits() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Repeated as a whole, so the complete `OR` is a common subexpression.
    let whole_or = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
    // A repeated first leg of `AND`/`OR` is extracted, while a repeated second
    // leg is left in place (presumably because it may be skipped at runtime).
    let first_leg = (col("a") + col("b")).eq(lit(0));
    let second_leg = (col("a") - col("b")).eq(lit(0));
    // Used as both legs of one `OR`; its first-leg occurrence makes it extractable.
    let both_legs = (col("a") * col("b")).eq(lit(0));

    let projection = vec![
        whole_or.clone().alias("c1"),
        whole_or.alias("c2"),
        first_leg.clone().or(second_leg.clone()).alias("c3"),
        first_leg.and(second_leg).alias("c4"),
        both_legs.clone().or(both_legs).alias("c5"),
    ];
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(projection)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
      Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1365
#[test]
fn test_volatile() -> Result<()> {
    let table_scan = test_table_scan()?;

    // `(a + b) + random()` is repeated. The volatile call keeps the whole sum
    // from being shared; only its deterministic `a + b` child is extracted.
    let volatile_sum = (col("a") + col("b")) + rand_func().call(vec![]);
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            volatile_sum.clone().alias("c1"),
            volatile_sum.alias("c2"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
      Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1389
#[test]
fn test_volatile_short_circuits() -> Result<()> {
    let table_scan = test_table_scan()?;

    let random_is_zero = rand_func().call(vec![]).eq(lit(0));
    // Volatile second leg: the deterministic first leg `a = 0` is still extracted.
    let deterministic_first = col("a").eq(lit(0)).or(random_is_zero.clone());
    // Volatile first leg: nothing in the expression is shared.
    let volatile_first = random_is_zero.or(col("b").eq(lit(0)));

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            deterministic_first.clone().alias("c1"),
            deterministic_first.alias("c2"),
            volatile_first.clone().alias("c3"),
            volatile_first.alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
      Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1419
#[test]
fn test_non_top_level_common_expression() -> Result<()> {
    let table_scan = test_table_scan()?;

    // The duplicated `a + b` lives in the lower projection; the upper
    // projection only forwards the aliased columns.
    let sum = col("a") + col("b");
    let lower = vec![sum.clone().alias("c1"), sum.alias("c2")];
    let upper = vec![col("c1"), col("c2")];
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(lower)?
        .project(upper)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: c1, c2
      Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
        Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1443
#[test]
fn test_nested_common_expression() -> Result<()> {
    let table_scan = test_table_scan()?;

    // `(a + b) * (a + b)` appears twice. Both the product and the inner sum
    // are extracted, each into its own projection.
    let a_plus_b = col("a") + col("b");
    let squared = a_plus_b.clone() * a_plus_b;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![squared.clone().alias("c1"), squared.alias("c2")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
      Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
        Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1467
#[test]
fn test_normalize_add_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a + b) * (b + a) = 30`: the two sums differ only in operand order and
    // are expected to map to the same common subexpression.
    let (left, right) = (col("a") + col("b"), col("b") + col("a"));
    let predicate = (left * right).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 * __common_expr_1 = Int32(30)
        Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1485
#[test]
fn test_normalize_multi_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a * b) + (b * a) = 30`: commuted products are expected to be shared.
    let (left, right) = (col("a") * col("b"), col("b") * col("a"));
    let predicate = (left + right).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1503
#[test]
fn test_normalize_bitset_and_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a & b) + (b & a) = 30`: commuted bitwise ANDs are expected to be shared.
    let (left, right) = (col("a") & col("b"), col("b") & col("a"));
    let predicate = (left + right).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1521
#[test]
fn test_normalize_bitset_or_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a | b) + (b | a) = 30`: commuted bitwise ORs are expected to be shared.
    let (left, right) = (col("a") | col("b"), col("b") | col("a"));
    let predicate = (left + right).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1539
#[test]
fn test_normalize_bitset_xor_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a ^ b) + (b ^ a) = 30`: commuted bitwise XORs are expected to be shared.
    let (left, right) = (col("a") ^ col("b"), col("b") ^ col("a"));
    let predicate = (left + right).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1557
#[test]
fn test_normalize_eq_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a = b) AND (b = a)`: commuted equalities are expected to be shared.
    let (left, right) = (col("a").eq(col("b")), col("b").eq(col("a")));
    let predicate = left.and(right);
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 AND __common_expr_1
        Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1575
#[test]
fn test_normalize_ne_expression() -> Result<()> {
    let table_scan = test_table_scan()?;
    // `(a != b) AND (b != a)`: commuted inequalities are expected to be shared.
    let (left, right) = (col("a").not_eq(col("b")), col("b").not_eq(col("a")));
    let predicate = left.and(right);
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 AND __common_expr_1
        Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1593
#[test]
fn test_normalize_complex_expression() -> Result<()> {
    // Case 1: `(a + b * c) - (b * c + a) = 30`.
    let table_scan = test_table_scan()?;
    let lhs = col("a") + col("b") * col("c");
    let rhs = col("b") * col("c") + col("a");
    let predicate = (lhs - rhs).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 - __common_expr_1 = Int32(30)
        Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )?;

    // Case 2: `((a + b / c) * c) / (c * (b / c + a)) + a = 30`; the numerator
    // and denominator are commuted forms of the same product.
    let table_scan = test_table_scan()?;
    let numerator = (col("a") + col("b") / col("c")) * col("c");
    let denominator = col("c") * (col("b") / col("c") + col("a"));
    let predicate = (numerator / denominator + col("a")).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
        Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )?;

    // Case 3: `(b / (a + c)) * (b / (c + a)) = 30`; only the inner sum is
    // commuted, and the quotients are still expected to be shared.
    let table_scan = test_table_scan()?;
    let left = col("b") / (col("a") + col("c"));
    let right = col("b") / (col("c") + col("a"));
    let predicate = (left * right).eq(lit(30));
    let plan = LogicalPlanBuilder::from(table_scan)
        .filter(predicate)?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 * __common_expr_1 = Int32(30)
        Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )?;

    Ok(())
}
1648
1649 #[derive(Debug)]
1650 pub struct TestUdf {
1651 signature: Signature,
1652 }
1653
1654 impl TestUdf {
1655 pub fn new() -> Self {
1656 Self {
1657 signature: Signature::numeric(1, Volatility::Immutable),
1658 }
1659 }
1660 }
1661
1662 impl ScalarUDFImpl for TestUdf {
1663 fn as_any(&self) -> &dyn Any {
1664 self
1665 }
1666 fn name(&self) -> &str {
1667 "my_udf"
1668 }
1669
1670 fn signature(&self) -> &Signature {
1671 &self.signature
1672 }
1673
1674 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1675 Ok(DataType::Int32)
1676 }
1677
1678 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1679 panic!("not implemented")
1680 }
1681 }
1682
#[test]
fn test_normalize_inner_binary_expression() -> Result<()> {
    // Unary `NOT` wrapping commuted equalities.
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            not(col("a").eq(col("b"))),
            not(col("b").eq(col("a"))),
        ])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
      Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // `IS NULL` wrapping commuted equalities.
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            is_null(col("a").eq(col("b"))),
            is_null(col("b").eq(col("a"))),
        ])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
      Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // `BETWEEN` whose tested expression is a commuted sum.
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b")).between(lit(0), lit(10)),
            (col("b") + col("a")).between(lit(0), lit(10)),
        ])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
      Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // `BETWEEN` whose lower bound is a commuted sum.
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            col("c").between(col("a") + col("b"), lit(10)),
            col("c").between(col("b") + col("a"), lit(10)),
        ])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
      Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // Scalar UDF whose argument is a commuted sum.
    let udf = ScalarUDF::from(TestUdf::new());
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            udf.call(vec![col("a") + col("b")]),
            udf.call(vec![col("b") + col("a")]),
        ])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
      Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1766
1767 fn rand_func() -> ScalarUDF {
1773 ScalarUDF::new_from_impl(RandomStub::new())
1774 }
1775
1776 #[derive(Debug)]
1777 struct RandomStub {
1778 signature: Signature,
1779 }
1780
1781 impl RandomStub {
1782 fn new() -> Self {
1783 Self {
1784 signature: Signature::exact(vec![], Volatility::Volatile),
1785 }
1786 }
1787 }
1788 impl ScalarUDFImpl for RandomStub {
1789 fn as_any(&self) -> &dyn Any {
1790 self
1791 }
1792
1793 fn name(&self) -> &str {
1794 "random"
1795 }
1796
1797 fn signature(&self) -> &Signature {
1798 &self.signature
1799 }
1800
1801 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1802 Ok(DataType::Float64)
1803 }
1804
1805 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1806 panic!("dummy - not implemented")
1807 }
1808 }
1809}