1mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28 Column, DFSchema, HashMap, JoinType, Result, assert_eq_or_internal_err,
29 get_required_group_by_exprs_indices, internal_datafusion_err, internal_err,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33 Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window,
34 logical_plan::LogicalPlan,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
/// Optimizer rule that removes columns a plan does not need.
///
/// The rule is registered under the name `"optimize_projections"`. Starting
/// from the root of the plan it tracks which output columns each node's parent
/// requires and rewrites every node so that it only produces (and only asks its
/// inputs for) those columns.
#[derive(Default, Debug)]
pub struct OptimizeProjections {}

impl OptimizeProjections {
    /// Creates a new [`OptimizeProjections`] rule.
    ///
    /// The rule holds no state, so this is equivalent to
    /// `OptimizeProjections::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
85
86impl OptimizerRule for OptimizeProjections {
87 fn name(&self) -> &str {
88 "optimize_projections"
89 }
90
91 fn apply_order(&self) -> Option<ApplyOrder> {
92 None
93 }
94
95 fn supports_rewrite(&self) -> bool {
96 true
97 }
98
99 fn rewrite(
100 &self,
101 plan: LogicalPlan,
102 config: &dyn OptimizerConfig,
103 ) -> Result<Transformed<LogicalPlan>> {
104 let indices = RequiredIndices::new_for_all_exprs(&plan);
106 optimize_projections(plan, config, indices)
107 }
108}
109
/// Recursively rewrites `plan` so that it, and its inputs, only compute the
/// columns its parent needs.
///
/// # Arguments
/// * `plan` - the plan node to rewrite; it is consumed.
/// * `config` - optimizer configuration, forwarded unchanged to every
///   recursive call.
/// * `indices` - the columns of `plan`'s output that the parent requires.
///
/// # Returns
/// The rewritten plan wrapped in [`Transformed`], whose flag records whether
/// anything was changed.
///
/// # Errors
/// Propagates errors from building the rewritten nodes, and returns an internal
/// error if the number of computed child requirements does not match the
/// number of inputs of `plan`.
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn optimize_projections(
    plan: LogicalPlan,
    config: &dyn OptimizerConfig,
    indices: RequiredIndices,
) -> Result<Transformed<LogicalPlan>> {
    // Node types with a dedicated rewrite return directly from this match;
    // everything else falls through to the generic child handling below.
    match plan {
        LogicalPlan::Projection(proj) => {
            // First try to fold this projection into a projection directly
            // beneath it, then drop the expressions the parent does not need.
            return merge_consecutive_projections(proj)?.transform_data(|proj| {
                rewrite_projection_given_requirements(proj, config, &indices)
            });
        }
        LogicalPlan::Aggregate(aggregate) => {
            let n_group_exprs = aggregate.group_expr_len()?;
            // Split the parent's requirements into a GROUP BY part and an
            // aggregate-expression part. The second part is used below to
            // index `aggregate.aggr_expr`.
            let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);

            // Output names of the existing GROUP BY expressions.
            let group_by_expr_existing = aggregate
                .group_expr
                .iter()
                .map(|group_by_expr| group_by_expr.schema_name().to_string())
                .collect::<Vec<_>>();

            let new_group_bys = if let Some(simplest_groupby_indices) =
                get_required_group_by_exprs_indices(
                    aggregate.input.schema(),
                    &group_by_expr_existing,
                ) {
                // Keep the GROUP BY expressions the helper reports as required
                // in addition to the ones the parent asked for.
                group_by_reqs
                    .append(&simplest_groupby_indices)
                    .get_at_indices(&aggregate.group_expr)
            } else {
                // No reduced set available: keep every GROUP BY expression.
                aggregate.group_expr
            };

            // Keep only the aggregate expressions the parent asked for.
            let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);

            if new_group_bys.is_empty() && new_aggr_expr.is_empty() {
                // Nothing is grouped and nothing is aggregated: replace the
                // node by an empty-schema relation that produces one row.
                return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
                    EmptyRelation {
                        produce_one_row: true,
                        schema: Arc::new(DFSchema::empty()),
                    },
                )));
            }

            // Columns of the aggregate's input that the surviving expressions
            // reference.
            let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
            let schema = aggregate.input.schema();
            let necessary_indices =
                RequiredIndices::new().with_exprs(schema, all_exprs_iter);
            let necessary_exprs = necessary_indices.get_required_exprs(schema);

            return optimize_projections(
                Arc::unwrap_or_clone(aggregate.input),
                config,
                necessary_indices,
            )?
            .transform_data(|aggregate_input| {
                // Narrow the rewritten input with a projection when that
                // reduces its column count.
                add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
            })?
            .map_data(|aggregate_input| {
                // Rebuild the aggregate over the new input with the retained
                // expressions only.
                Aggregate::try_new(
                    Arc::new(aggregate_input),
                    new_group_bys,
                    new_aggr_expr,
                )
                .map(LogicalPlan::Aggregate)
            });
        }
        LogicalPlan::Window(window) => {
            let input_schema = Arc::clone(window.input.schema());
            let n_input_fields = input_schema.fields().len();
            // Split the parent's requirements at the number of input fields:
            // the first part is passed on to the child, the second part is
            // used to index `window.window_expr`.
            let (child_reqs, window_reqs) = indices.split_off(n_input_fields);

            // Keep only the window expressions the parent asked for.
            let new_window_expr = window_reqs.get_at_indices(&window.window_expr);

            // The child must provide what the parent needs from it plus every
            // column the retained window expressions reference.
            let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);

            return optimize_projections(
                Arc::unwrap_or_clone(window.input),
                config,
                required_indices.clone(),
            )?
            .transform_data(|window_child| {
                if new_window_expr.is_empty() {
                    // No window expression survives: return the child as is.
                    Ok(Transformed::no(window_child))
                } else {
                    // `required_indices` was computed against `input_schema`,
                    // so the expressions are resolved against that schema too.
                    let required_exprs =
                        required_indices.get_required_exprs(&input_schema);
                    let window_child =
                        add_projection_on_top_if_helpful(window_child, required_exprs)?
                            .data;
                    Window::try_new(new_window_expr, Arc::new(window_child))
                        .map(LogicalPlan::Window)
                        .map(Transformed::yes)
                }
            });
        }
        LogicalPlan::TableScan(table_scan) => {
            let TableScan {
                table_name,
                source,
                projection,
                filters,
                fetch,
                projected_schema: _,
            } = table_scan;

            // If the scan already carries a projection, translate each
            // required index through it (`projection[idx]`); otherwise the
            // required indices are used directly.
            let projection = match &projection {
                Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
                None => indices.into_inner(),
            };
            // The scan is rebuilt with the new projection and is always
            // reported as transformed.
            return TableScan::try_new(
                table_name,
                source,
                Some(projection),
                filters,
                fetch,
            )
            .map(LogicalPlan::TableScan)
            .map(Transformed::yes);
        }
        _ => {}
    };

    // Generic path: compute, for every input of `plan`, the columns that input
    // has to provide (one entry per input, in input order).
    let mut child_required_indices: Vec<RequiredIndices> = match &plan {
        LogicalPlan::Sort(_)
        | LogicalPlan::Filter(_)
        | LogicalPlan::Repartition(_)
        | LogicalPlan::Union(_)
        | LogicalPlan::SubqueryAlias(_)
        | LogicalPlan::Distinct(Distinct::On(_)) => {
            // Forward the parent's requirements plus the columns used by this
            // node's own expressions, and mark each input as one where adding
            // a projection is beneficial.
            plan.inputs()
                .into_iter()
                .map(|input| {
                    indices
                        .clone()
                        .with_projection_beneficial()
                        .with_plan_exprs(&plan, input.schema())
                })
                .collect::<Result<_>>()?
        }
        LogicalPlan::Limit(_) => {
            // Same forwarding as above, but without the
            // projection-beneficial mark.
            plan.inputs()
                .into_iter()
                .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
                .collect::<Result<_>>()?
        }
        LogicalPlan::Copy(_)
        | LogicalPlan::Ddl(_)
        | LogicalPlan::Dml(_)
        | LogicalPlan::Explain(_)
        | LogicalPlan::Analyze(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::Statement(_)
        | LogicalPlan::Distinct(Distinct::All(_)) => {
            // The parent's requirements are ignored here: each input gets
            // fresh requirements built by `new_for_all_exprs`.
            plan.inputs()
                .into_iter()
                .map(RequiredIndices::new_for_all_exprs)
                .collect()
        }
        LogicalPlan::Extension(extension) => {
            // Ask the user-defined node how the required output columns map
            // onto its children; if it cannot tell, leave the node untouched.
            let Some(necessary_children_indices) =
                extension.node.necessary_children_exprs(indices.indices())
            else {
                return Ok(Transformed::no(plan));
            };
            let children = extension.node.inputs();
            assert_eq_or_internal_err!(
                children.len(),
                necessary_children_indices.len(),
                "Inconsistent length between children and necessary children indices. \
                Make sure `.necessary_children_exprs` implementation of the \
                `UserDefinedLogicalNode` is consistent with actual children length \
                for the node."
            );
            // Add the columns used by the node's own expressions to what each
            // child was reported to need.
            children
                .into_iter()
                .zip(necessary_children_indices)
                .map(|(child, necessary_indices)| {
                    RequiredIndices::new_from_indices(necessary_indices)
                        .with_plan_exprs(&plan, child.schema())
                })
                .collect::<Result<Vec<_>>>()?
        }
        LogicalPlan::EmptyRelation(_)
        | LogicalPlan::Values(_)
        | LogicalPlan::DescribeTable(_) => {
            // These nodes are returned unchanged.
            return Ok(Transformed::no(plan));
        }
        LogicalPlan::RecursiveQuery(recursive) => {
            // Leave the recursive query untouched when either of its terms
            // contains subqueries other than references to this CTE.
            if plan_contains_other_subqueries(
                recursive.static_term.as_ref(),
                &recursive.name,
            ) || plan_contains_other_subqueries(
                recursive.recursive_term.as_ref(),
                &recursive.name,
            ) {
                return Ok(Transformed::no(plan));
            }

            plan.inputs()
                .into_iter()
                .map(|input| {
                    indices
                        .clone()
                        .with_projection_beneficial()
                        .with_plan_exprs(&plan, input.schema())
                })
                .collect::<Result<Vec<_>>>()?
        }
        LogicalPlan::Join(join) => {
            // Route the parent's requirements to the left and right inputs
            // according to the join type, then add the columns the join's own
            // expressions use on each side.
            let left_len = join.left.schema().fields().len();
            let (left_req_indices, right_req_indices) =
                split_join_requirements(left_len, indices, &join.join_type);
            let left_indices =
                left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
            let right_indices =
                right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
            // Both sides are marked as benefiting from a projection.
            vec![
                left_indices.with_projection_beneficial(),
                right_indices.with_projection_beneficial(),
            ]
        }
        // These variants all return from the first match above, so reaching
        // them here indicates a logic error.
        LogicalPlan::Projection(_)
        | LogicalPlan::Aggregate(_)
        | LogicalPlan::Window(_)
        | LogicalPlan::TableScan(_) => {
            return internal_err!(
                "OptimizeProjection: should have handled in the match statement above"
            );
        }
        LogicalPlan::Unnest(Unnest {
            input,
            dependency_indices,
            ..
        }) => {
            // Start from the columns the unnest node's own expressions use.
            let required_indices =
                RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;

            // Translate each output column required by the parent to an input
            // column through `dependency_indices`; indices without an entry
            // are skipped.
            let mut additional_necessary_child_indices = Vec::new();
            indices.indices().iter().for_each(|idx| {
                if let Some(index) = dependency_indices.get(*idx) {
                    additional_necessary_child_indices.push(*index);
                }
            });
            vec![required_indices.append(&additional_necessary_child_indices)]
        }
    };

    // The closure below takes requirements with `pop()`, i.e. from the back,
    // so reverse the vector to hand them out in input order.
    child_required_indices.reverse();
    assert_eq_or_internal_err!(
        child_required_indices.len(),
        plan.inputs().len(),
        "OptimizeProjection: child_required_indices length mismatch with plan inputs"
    );

    // Rewrite every child with the requirements computed for it.
    let transformed_plan = plan.map_children(|child| {
        let required_indices = child_required_indices.pop().ok_or_else(|| {
            internal_datafusion_err!(
                "Unexpected number of required_indices in OptimizeProjections rule"
            )
        })?;

        // Both values are read before `required_indices` is moved into the
        // recursive call.
        let projection_beneficial = required_indices.projection_beneficial();
        let project_exprs = required_indices.get_required_exprs(child.schema());

        optimize_projections(child, config, required_indices)?.transform_data(
            |new_input| {
                if projection_beneficial {
                    add_projection_on_top_if_helpful(new_input, project_exprs)
                } else {
                    Ok(Transformed::no(new_input))
                }
            },
        )
    })?;

    // Only recompute this node's schema when a child actually changed.
    if transformed_plan.transformed {
        transformed_plan.map_data(|plan| plan.recompute_schema())
    } else {
        Ok(transformed_plan)
    }
}
473
474fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
507 let Projection {
508 expr,
509 input,
510 schema,
511 ..
512 } = proj;
513 let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
514 return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
515 };
516
517 if prev_projection.expr == expr {
520 return Projection::try_new_with_schema(
521 expr,
522 Arc::clone(&prev_projection.input),
523 schema,
524 )
525 .map(Transformed::yes);
526 }
527
528 let mut column_referral_map = HashMap::<&Column, usize>::new();
530 expr.iter()
531 .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
532
533 if column_referral_map.into_iter().any(|(col, usage)| {
537 usage > 1
538 && !is_expr_trivial(
539 &prev_projection.expr
540 [prev_projection.schema.index_of_column(col).unwrap()],
541 )
542 }) {
543 return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
545 }
546
547 let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
548 unreachable!();
550 };
551
552 let name_preserver = NamePreserver::new_for_projection();
555 let mut original_names = vec![];
556 let new_exprs = expr.map_elements(|expr| {
557 original_names.push(name_preserver.save(&expr));
558
559 match expr {
561 Expr::Alias(Alias {
562 expr,
563 relation,
564 name,
565 metadata,
566 }) => rewrite_expr(*expr, &prev_projection).map(|result| {
567 result.update_data(|expr| {
568 Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata))
569 })
570 }),
571 e => rewrite_expr(e, &prev_projection),
572 }
573 })?;
574
575 if new_exprs.transformed {
578 let new_exprs = new_exprs
580 .data
581 .into_iter()
582 .zip(original_names)
583 .map(|(expr, original_name)| original_name.restore(expr))
584 .collect::<Vec<_>>();
585 Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
586 } else {
587 let input = Arc::new(LogicalPlan::Projection(prev_projection));
589 Projection::try_new_with_schema(new_exprs.data, input, schema)
590 .map(Transformed::no)
591 }
592}
593
594fn is_expr_trivial(expr: &Expr) -> bool {
596 matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
597}
598
599fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
645 expr.transform_up(|expr| {
646 match expr {
647 Expr::Alias(alias) => {
649 match alias
650 .metadata
651 .as_ref()
652 .map(|h| h.is_empty())
653 .unwrap_or(true)
654 {
655 true => Ok(Transformed::yes(*alias.expr)),
656 false => Ok(Transformed::no(Expr::Alias(alias))),
657 }
658 }
659 Expr::Column(col) => {
660 let idx = input.schema.index_of_column(&col)?;
662 let input_expr = input.expr[idx].clone().unalias_nested().data;
670 Ok(Transformed::yes(input_expr))
671 }
672 _ => Ok(Transformed::no(expr)),
674 }
675 })
676}
677
678fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
687 expr.apply(|expr| {
689 match expr {
690 Expr::OuterReferenceColumn(_, col) => {
691 columns.insert(col);
692 }
693 Expr::ScalarSubquery(subquery) => {
694 outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
695 }
696 Expr::Exists(exists) => {
697 outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
698 }
699 Expr::InSubquery(insubquery) => {
700 outer_columns_helper_multi(
701 &insubquery.subquery.outer_ref_columns,
702 columns,
703 );
704 }
705 _ => {}
706 };
707 Ok(TreeNodeRecursion::Continue)
708 })
709 .unwrap();
711}
712
713fn outer_columns_helper_multi<'a, 'b>(
722 exprs: impl IntoIterator<Item = &'a Expr>,
723 columns: &'b mut HashSet<&'a Column>,
724) {
725 exprs.into_iter().for_each(|e| outer_columns(e, columns));
726}
727
728fn split_join_requirements(
758 left_len: usize,
759 indices: RequiredIndices,
760 join_type: &JoinType,
761) -> (RequiredIndices, RequiredIndices) {
762 match join_type {
763 JoinType::Inner
765 | JoinType::Left
766 | JoinType::Right
767 | JoinType::Full
768 | JoinType::LeftMark
769 | JoinType::RightMark => {
770 indices.split_off(left_len)
773 }
774 JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
776 JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
779 }
780}
781
782fn add_projection_on_top_if_helpful(
800 plan: LogicalPlan,
801 project_exprs: Vec<Expr>,
802) -> Result<Transformed<LogicalPlan>> {
803 if project_exprs.len() >= plan.schema().fields().len() {
805 Ok(Transformed::no(plan))
806 } else {
807 Projection::try_new(project_exprs, Arc::new(plan))
808 .map(LogicalPlan::Projection)
809 .map(Transformed::yes)
810 }
811}
812
813fn rewrite_projection_given_requirements(
831 proj: Projection,
832 config: &dyn OptimizerConfig,
833 indices: &RequiredIndices,
834) -> Result<Transformed<LogicalPlan>> {
835 let Projection { expr, input, .. } = proj;
836
837 let exprs_used = indices.get_at_indices(&expr);
838
839 let required_indices =
840 RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
841
842 optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
845 .transform_data(|input| {
846 if is_projection_unnecessary(&input, &exprs_used)? {
847 Ok(Transformed::yes(input))
848 } else {
849 Projection::try_new(exprs_used, Arc::new(input))
850 .map(LogicalPlan::Projection)
851 .map(Transformed::yes)
852 }
853 })
854}
855
856pub fn is_projection_unnecessary(
860 input: &LogicalPlan,
861 proj_exprs: &[Expr],
862) -> Result<bool> {
863 if proj_exprs.len() != input.schema().fields().len() {
865 return Ok(false);
866 }
867 Ok(input.schema().iter().zip(proj_exprs.iter()).all(
868 |((field_relation, field_name), expr)| {
869 if let Expr::Column(col) = expr {
871 col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
872 } else {
873 false
874 }
875 },
876 ))
877}
878
879fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
885 if let LogicalPlan::SubqueryAlias(alias) = plan
886 && alias.alias.table() != cte_name
887 && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
888 {
889 return true;
890 }
891
892 let mut found = false;
893 plan.apply_expressions(|expr| {
894 if expr_contains_subquery(expr) {
895 found = true;
896 Ok(TreeNodeRecursion::Stop)
897 } else {
898 Ok(TreeNodeRecursion::Continue)
899 }
900 })
901 .expect("expression traversal never fails");
902 if found {
903 return true;
904 }
905
906 plan.inputs()
907 .into_iter()
908 .any(|child| plan_contains_other_subqueries(child, cte_name))
909}
910
911fn expr_contains_subquery(expr: &Expr) -> bool {
912 expr.exists(|e| match e {
913 Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
914 _ => Ok(false),
915 })
916 .unwrap()
918}
919
920fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool {
921 match plan {
922 LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name,
923 LogicalPlan::SubqueryAlias(alias) => {
924 subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
925 }
926 _ => {
927 let inputs = plan.inputs();
928 if inputs.len() == 1 {
929 subquery_alias_targets_recursive_cte(inputs[0], cte_name)
930 } else {
931 false
932 }
933 }
934 }
935}
936
937#[cfg(test)]
938mod tests {
939 use std::cmp::Ordering;
940 use std::collections::HashMap;
941 use std::fmt::Formatter;
942 use std::ops::Add;
943 use std::sync::Arc;
944 use std::vec;
945
946 use crate::optimize_projections::OptimizeProjections;
947 use crate::optimizer::Optimizer;
948 use crate::test::{
949 assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
950 test_table_scan_with_name,
951 };
952 use crate::{OptimizerContext, OptimizerRule};
953 use arrow::datatypes::{DataType, Field, Schema};
954 use datafusion_common::{
955 Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
956 };
957 use datafusion_expr::ExprFunctionExt;
958 use datafusion_expr::{
959 BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection,
960 UserDefinedLogicalNodeCore, WindowFunctionDefinition, binary_expr,
961 build_join_schema,
962 builder::table_scan_with_filters,
963 col,
964 expr::{self, Cast},
965 lit,
966 logical_plan::{builder::LogicalPlanBuilder, table_scan},
967 not, try_cast, when,
968 };
969 use insta::assert_snapshot;
970
971 use crate::assert_optimized_plan_eq_snapshot;
972 use datafusion_functions_aggregate::count::count_udaf;
973 use datafusion_functions_aggregate::expr_fn::{count, max, min};
974 use datafusion_functions_aggregate::min_max::max_udaf;
975
    /// Optimizes `$plan` with only the `OptimizeProjections` rule, limited to a
    /// single optimizer pass, and compares the result against the inline
    /// snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            // One pass only, so the snapshot shows what a single application of
            // the rule produces.
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }
991
    /// Single-input user-defined plan node used by the tests below.
    #[derive(Debug, Hash, PartialEq, Eq)]
    struct NoOpUserDefined {
        // Expressions reported by `expressions()`; empty unless set through
        // `with_exprs`.
        exprs: Vec<Expr>,
        // Schema reported by `schema()`.
        schema: DFSchemaRef,
        // The only child of the node.
        input: Arc<LogicalPlan>,
    }
998
999 impl NoOpUserDefined {
1000 fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
1001 Self {
1002 exprs: vec![],
1003 schema,
1004 input,
1005 }
1006 }
1007
1008 fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
1009 self.exprs = exprs;
1010 self
1011 }
1012 }
1013
1014 impl PartialOrd for NoOpUserDefined {
1016 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1017 match self.exprs.partial_cmp(&other.exprs) {
1018 Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
1019 cmp => cmp,
1020 }
1021 .filter(|cmp| *cmp != Ordering::Equal || self == other)
1023 }
1024 }
1025
1026 impl UserDefinedLogicalNodeCore for NoOpUserDefined {
1027 fn name(&self) -> &str {
1028 "NoOpUserDefined"
1029 }
1030
1031 fn inputs(&self) -> Vec<&LogicalPlan> {
1032 vec![&self.input]
1033 }
1034
1035 fn schema(&self) -> &DFSchemaRef {
1036 &self.schema
1037 }
1038
1039 fn expressions(&self) -> Vec<Expr> {
1040 self.exprs.clone()
1041 }
1042
1043 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1044 write!(f, "NoOpUserDefined")
1045 }
1046
1047 fn with_exprs_and_inputs(
1048 &self,
1049 exprs: Vec<Expr>,
1050 mut inputs: Vec<LogicalPlan>,
1051 ) -> Result<Self> {
1052 Ok(Self {
1053 exprs,
1054 input: Arc::new(inputs.swap_remove(0)),
1055 schema: Arc::clone(&self.schema),
1056 })
1057 }
1058
1059 fn necessary_children_exprs(
1060 &self,
1061 output_columns: &[usize],
1062 ) -> Option<Vec<Vec<usize>>> {
1063 Some(vec![output_columns.to_vec()])
1065 }
1066
1067 fn supports_limit_pushdown(&self) -> bool {
1068 false }
1070 }
1071
    /// Two-input user-defined plan node used by the tests; its schema is built
    /// in `new` as the inner-join schema of its two children.
    #[derive(Debug, Hash, PartialEq, Eq)]
    struct UserDefinedCrossJoin {
        // Expressions reported by `expressions()`; `new` leaves this empty.
        exprs: Vec<Expr>,
        // Schema reported by `schema()`.
        schema: DFSchemaRef,
        // First input of the node.
        left_child: Arc<LogicalPlan>,
        // Second input of the node.
        right_child: Arc<LogicalPlan>,
    }
1079
1080 impl UserDefinedCrossJoin {
1081 fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
1082 let left_schema = left_child.schema();
1083 let right_schema = right_child.schema();
1084 let schema = Arc::new(
1085 build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
1086 );
1087 Self {
1088 exprs: vec![],
1089 schema,
1090 left_child,
1091 right_child,
1092 }
1093 }
1094 }
1095
1096 impl PartialOrd for UserDefinedCrossJoin {
1098 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1099 match self.exprs.partial_cmp(&other.exprs) {
1100 Some(Ordering::Equal) => {
1101 match self.left_child.partial_cmp(&other.left_child) {
1102 Some(Ordering::Equal) => {
1103 self.right_child.partial_cmp(&other.right_child)
1104 }
1105 cmp => cmp,
1106 }
1107 }
1108 cmp => cmp,
1109 }
1110 .filter(|cmp| *cmp != Ordering::Equal || self == other)
1112 }
1113 }
1114
1115 impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
1116 fn name(&self) -> &str {
1117 "UserDefinedCrossJoin"
1118 }
1119
1120 fn inputs(&self) -> Vec<&LogicalPlan> {
1121 vec![&self.left_child, &self.right_child]
1122 }
1123
1124 fn schema(&self) -> &DFSchemaRef {
1125 &self.schema
1126 }
1127
1128 fn expressions(&self) -> Vec<Expr> {
1129 self.exprs.clone()
1130 }
1131
1132 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1133 write!(f, "UserDefinedCrossJoin")
1134 }
1135
1136 fn with_exprs_and_inputs(
1137 &self,
1138 exprs: Vec<Expr>,
1139 mut inputs: Vec<LogicalPlan>,
1140 ) -> Result<Self> {
1141 assert_eq!(inputs.len(), 2);
1142 Ok(Self {
1143 exprs,
1144 left_child: Arc::new(inputs.remove(0)),
1145 right_child: Arc::new(inputs.remove(0)),
1146 schema: Arc::clone(&self.schema),
1147 })
1148 }
1149
1150 fn necessary_children_exprs(
1151 &self,
1152 output_columns: &[usize],
1153 ) -> Option<Vec<Vec<usize>>> {
1154 let left_child_len = self.left_child.schema().fields().len();
1155 let mut left_reqs = vec![];
1156 let mut right_reqs = vec![];
1157 for &out_idx in output_columns {
1158 if out_idx < left_child_len {
1159 left_reqs.push(out_idx);
1160 } else {
1161 right_reqs.push(out_idx - left_child_len)
1164 }
1165 }
1166 Some(vec![left_reqs, right_reqs])
1167 }
1168
1169 fn supports_limit_pushdown(&self) -> bool {
1170 false }
1172 }
1173
    /// Two stacked projections are merged into one, and the scan only reads
    /// column `a`.
    #[test]
    fn merge_two_projection() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: Int32(1) + test.a
          TableScan: test projection=[a]
        "
        )
    }
1190
    /// Three stacked projections are merged into one; column `b`, selected only
    /// by the innermost projection, is not read by the scan.
    #[test]
    fn merge_three_projection() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .project(vec![col("a")])?
            .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: Int32(1) + test.a
          TableScan: test projection=[a]
        "
        )
    }
1208
    /// An aliasing projection on top of a plain column projection is merged
    /// into a single projection that keeps the alias.
    #[test]
    fn merge_alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .project(vec![col("a").alias("alias")])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS alias
          TableScan: test projection=[a]
        "
        )
    }
1225
    /// The intermediate aliases `alias1` and `alias2` disappear when the
    /// projections are merged; only the outermost alias remains.
    #[test]
    fn merge_nested_alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("alias1").alias("alias2")])?
            .project(vec![col("alias2").alias("alias")])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS alias
          TableScan: test projection=[a]
        "
        )
    }
1242
    /// The outer `count(1)` needs no column from the inner aggregate, so the
    /// inner aggregate and its scan are replaced by a one-row empty relation.
    #[test]
    fn test_nested_count() -> Result<()> {
        let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);

        let groups: Vec<Expr> = vec![];

        let plan = table_scan(TableReference::none(), &schema, None)
            .unwrap()
            .aggregate(groups.clone(), vec![count(lit(1))])
            .unwrap()
            .aggregate(groups, vec![count(lit(1))])
            .unwrap()
            .build()
            .unwrap();

        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
          EmptyRelation: rows=1
        "
        )
    }
1266
    /// A negated column only requires that column from the scan.
    #[test]
    fn test_neg_push_down() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![-col("a")])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: (- test.a)
          TableScan: test projection=[a]
        "
        )
    }
1282
    /// `IS NULL` on a column only requires that column from the scan.
    #[test]
    fn test_is_null() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_null()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS NULL
          TableScan: test projection=[a]
        "
        )
    }
1298
    /// `IS NOT NULL` on a column only requires that column from the scan.
    #[test]
    fn test_is_not_null() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_not_null()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS NOT NULL
          TableScan: test projection=[a]
        "
        )
    }
1314
    /// `IS TRUE` on a column only requires that column from the scan.
    #[test]
    fn test_is_true() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_true()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS TRUE
          TableScan: test projection=[a]
        "
        )
    }
1330
    /// `IS NOT TRUE` on a column only requires that column from the scan.
    #[test]
    fn test_is_not_true() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_not_true()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS NOT TRUE
          TableScan: test projection=[a]
        "
        )
    }
1346
    /// `IS FALSE` on a column only requires that column from the scan.
    #[test]
    fn test_is_false() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_false()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS FALSE
          TableScan: test projection=[a]
        "
        )
    }
1362
    /// `IS NOT FALSE` on a column only requires that column from the scan.
    #[test]
    fn test_is_not_false() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_not_false()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS NOT FALSE
          TableScan: test projection=[a]
        "
        )
    }
1378
    /// `IS UNKNOWN` on a column only requires that column from the scan.
    #[test]
    fn test_is_unknown() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_unknown()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS UNKNOWN
          TableScan: test projection=[a]
        "
        )
    }
1394
    /// `IS NOT UNKNOWN` on a column only requires that column from the scan.
    #[test]
    fn test_is_not_unknown() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").is_not_unknown()])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a IS NOT UNKNOWN
          TableScan: test projection=[a]
        "
        )
    }
1410
    /// `NOT` applied to a column only requires that column from the scan.
    #[test]
    fn test_not() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![not(col("a"))])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: NOT test.a
          TableScan: test projection=[a]
        "
        )
    }
1426
    /// `TRY_CAST` of a column only requires that column from the scan.
    #[test]
    fn test_try_cast() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![try_cast(col("a"), DataType::Float64)])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: TRY_CAST(test.a AS Float64)
          TableScan: test projection=[a]
        "
        )
    }
1442
    /// `SIMILAR TO` with a literal pattern only requires the matched column
    /// from the scan.
    #[test]
    fn test_similar_to() -> Result<()> {
        let table_scan = test_table_scan()?;
        let expr = Box::new(col("a"));
        let pattern = Box::new(lit("[0-9]"));
        let similar_to_expr =
            Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![similar_to_expr])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r#"
        Projection: test.a SIMILAR TO Utf8("[0-9]")
          TableScan: test projection=[a]
        "#
        )
    }
1462
    /// `BETWEEN` with literal bounds only requires the tested column from the
    /// scan.
    #[test]
    fn test_between() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").between(lit(1), lit(3))])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a BETWEEN Int32(1) AND Int32(3)
          TableScan: test projection=[a]
        "
        )
    }
1478
    /// The outer projection references `a` twice, but `a` is a plain column in
    /// the inner projection, so the two projections are still merged and the
    /// literal `d` is inlined into the `CASE` expression.
    #[test]
    fn test_case_merged() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), lit(0).alias("d")])?
            .project(vec![
                col("a"),
                when(col("a").eq(lit(1)), lit(10))
                    .otherwise(col("d"))?
                    .alias("d"),
            ])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
          TableScan: test projection=[a]
        "
        )
    }
1501
    /// The outer projection references `a` twice and `a` is computed
    /// (`test.a + 1`) by the inner projection, so the projections are not
    /// merged; the scan is still narrowed to column `a`.
    #[test]
    fn test_derived_column() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
            .project(vec![
                col("a"),
                when(col("a").eq(lit(1)), lit(10))
                    .otherwise(col("d"))?
                    .alias("d"),
            ])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
          Projection: test.a + Int32(1) AS a, Int32(0) AS d
            TableScan: test projection=[a]
        "
        )
    }
1526
    /// Column requirements pass through a user-defined node without
    /// expressions: only `a` is read by the scan below it.
    #[test]
    fn test_user_defined_logical_plan_node() -> Result<()> {
        let table_scan = test_table_scan()?;
        let custom_plan = LogicalPlan::Extension(Extension {
            node: Arc::new(NoOpUserDefined::new(
                Arc::clone(table_scan.schema()),
                Arc::new(table_scan.clone()),
            )),
        });
        let plan = LogicalPlanBuilder::from(custom_plan)
            .project(vec![col("a"), lit(0).alias("d")])?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, Int32(0) AS d
          NoOpUserDefined
            TableScan: test projection=[a]
        "
        )
    }
1551
1552 #[test]
1557 fn test_user_defined_logical_plan_node2() -> Result<()> {
1558 let table_scan = test_table_scan()?;
1559 let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
1560 let custom_plan = LogicalPlan::Extension(Extension {
1561 node: Arc::new(
1562 NoOpUserDefined::new(
1563 Arc::clone(table_scan.schema()),
1564 Arc::new(table_scan.clone()),
1565 )
1566 .with_exprs(exprs),
1567 ),
1568 });
1569 let plan = LogicalPlanBuilder::from(custom_plan)
1570 .project(vec![col("a"), lit(0).alias("d")])?
1571 .build()?;
1572
1573 assert_optimized_plan_equal!(
1574 plan,
1575 @r"
1576 Projection: test.a, Int32(0) AS d
1577 NoOpUserDefined
1578 TableScan: test projection=[a, b]
1579 "
1580 )
1581 }
1582
1583 #[test]
1589 fn test_user_defined_logical_plan_node3() -> Result<()> {
1590 let table_scan = test_table_scan()?;
1591 let left_expr = Expr::Column(Column::from_qualified_name("b"));
1592 let right_expr = Expr::Column(Column::from_qualified_name("c"));
1593 let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
1594 Box::new(left_expr),
1595 Operator::Plus,
1596 Box::new(right_expr),
1597 ));
1598 let exprs = vec![binary_expr];
1599 let custom_plan = LogicalPlan::Extension(Extension {
1600 node: Arc::new(
1601 NoOpUserDefined::new(
1602 Arc::clone(table_scan.schema()),
1603 Arc::new(table_scan.clone()),
1604 )
1605 .with_exprs(exprs),
1606 ),
1607 });
1608 let plan = LogicalPlanBuilder::from(custom_plan)
1609 .project(vec![col("a"), lit(0).alias("d")])?
1610 .build()?;
1611
1612 assert_optimized_plan_equal!(
1613 plan,
1614 @r"
1615 Projection: test.a, Int32(0) AS d
1616 NoOpUserDefined
1617 TableScan: test projection=[a, b, c]
1618 "
1619 )
1620 }
1621
1622 #[test]
1627 fn test_user_defined_logical_plan_node4() -> Result<()> {
1628 let left_table = test_table_scan_with_name("l")?;
1629 let right_table = test_table_scan_with_name("r")?;
1630 let custom_plan = LogicalPlan::Extension(Extension {
1631 node: Arc::new(UserDefinedCrossJoin::new(
1632 Arc::new(left_table),
1633 Arc::new(right_table),
1634 )),
1635 });
1636 let plan = LogicalPlanBuilder::from(custom_plan)
1637 .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
1638 .build()?;
1639
1640 assert_optimized_plan_equal!(
1641 plan,
1642 @r"
1643 Projection: l.a, l.c, r.a, Int32(0) AS d
1644 UserDefinedCrossJoin
1645 TableScan: l projection=[a, c]
1646 TableScan: r projection=[a]
1647 "
1648 )
1649 }
1650
1651 #[test]
1652 fn aggregate_no_group_by() -> Result<()> {
1653 let table_scan = test_table_scan()?;
1654
1655 let plan = LogicalPlanBuilder::from(table_scan)
1656 .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1657 .build()?;
1658
1659 assert_optimized_plan_equal!(
1660 plan,
1661 @r"
1662 Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1663 TableScan: test projection=[b]
1664 "
1665 )
1666 }
1667
1668 #[test]
1669 fn aggregate_group_by() -> Result<()> {
1670 let table_scan = test_table_scan()?;
1671
1672 let plan = LogicalPlanBuilder::from(table_scan)
1673 .aggregate(vec![col("c")], vec![max(col("b"))])?
1674 .build()?;
1675
1676 assert_optimized_plan_equal!(
1677 plan,
1678 @r"
1679 Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
1680 TableScan: test projection=[b, c]
1681 "
1682 )
1683 }
1684
1685 #[test]
1686 fn aggregate_group_by_with_table_alias() -> Result<()> {
1687 let table_scan = test_table_scan()?;
1688
1689 let plan = LogicalPlanBuilder::from(table_scan)
1690 .alias("a")?
1691 .aggregate(vec![col("c")], vec![max(col("b"))])?
1692 .build()?;
1693
1694 assert_optimized_plan_equal!(
1695 plan,
1696 @r"
1697 Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
1698 SubqueryAlias: a
1699 TableScan: test projection=[b, c]
1700 "
1701 )
1702 }
1703
1704 #[test]
1705 fn aggregate_no_group_by_with_filter() -> Result<()> {
1706 let table_scan = test_table_scan()?;
1707
1708 let plan = LogicalPlanBuilder::from(table_scan)
1709 .filter(col("c").gt(lit(1)))?
1710 .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1711 .build()?;
1712
1713 assert_optimized_plan_equal!(
1714 plan,
1715 @r"
1716 Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1717 Projection: test.b
1718 Filter: test.c > Int32(1)
1719 TableScan: test projection=[b, c]
1720 "
1721 )
1722 }
1723
1724 #[test]
1725 fn aggregate_with_periods() -> Result<()> {
1726 let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
1727
1728 let plan = table_scan(Some("m4"), &schema, None)?
1735 .aggregate(
1736 Vec::<Expr>::new(),
1737 vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")],
1738 )?
1739 .project([col(Column::new_unqualified("tag.one"))])?
1740 .build()?;
1741
1742 assert_optimized_plan_equal!(
1743 plan,
1744 @r"
1745 Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
1746 TableScan: m4 projection=[tag.one]
1747 "
1748 )
1749 }
1750
1751 #[test]
1752 fn redundant_project() -> Result<()> {
1753 let table_scan = test_table_scan()?;
1754
1755 let plan = LogicalPlanBuilder::from(table_scan)
1756 .project(vec![col("a"), col("b"), col("c")])?
1757 .project(vec![col("a"), col("c"), col("b")])?
1758 .build()?;
1759 assert_optimized_plan_equal!(
1760 plan,
1761 @r"
1762 Projection: test.a, test.c, test.b
1763 TableScan: test projection=[a, b, c]
1764 "
1765 )
1766 }
1767
1768 #[test]
1769 fn reorder_scan() -> Result<()> {
1770 let schema = Schema::new(test_table_scan_fields());
1771
1772 let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?;
1773 assert_optimized_plan_equal!(
1774 plan,
1775 @"TableScan: test projection=[b, a, c]"
1776 )
1777 }
1778
1779 #[test]
1780 fn reorder_scan_projection() -> Result<()> {
1781 let schema = Schema::new(test_table_scan_fields());
1782
1783 let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?
1784 .project(vec![col("a"), col("b")])?
1785 .build()?;
1786 assert_optimized_plan_equal!(
1787 plan,
1788 @r"
1789 Projection: test.a, test.b
1790 TableScan: test projection=[b, a]
1791 "
1792 )
1793 }
1794
1795 #[test]
1796 fn reorder_projection() -> Result<()> {
1797 let table_scan = test_table_scan()?;
1798
1799 let plan = LogicalPlanBuilder::from(table_scan)
1800 .project(vec![col("c"), col("b"), col("a")])?
1801 .build()?;
1802 assert_optimized_plan_equal!(
1803 plan,
1804 @r"
1805 Projection: test.c, test.b, test.a
1806 TableScan: test projection=[a, b, c]
1807 "
1808 )
1809 }
1810
1811 #[test]
1812 fn noncontinuous_redundant_projection() -> Result<()> {
1813 let table_scan = test_table_scan()?;
1814
1815 let plan = LogicalPlanBuilder::from(table_scan)
1816 .project(vec![col("c"), col("b"), col("a")])?
1817 .filter(col("c").gt(lit(1)))?
1818 .project(vec![col("c"), col("a"), col("b")])?
1819 .filter(col("b").gt(lit(1)))?
1820 .filter(col("a").gt(lit(1)))?
1821 .project(vec![col("a"), col("c"), col("b")])?
1822 .build()?;
1823 assert_optimized_plan_equal!(
1824 plan,
1825 @r"
1826 Projection: test.a, test.c, test.b
1827 Filter: test.a > Int32(1)
1828 Filter: test.b > Int32(1)
1829 Projection: test.c, test.a, test.b
1830 Filter: test.c > Int32(1)
1831 Projection: test.c, test.b, test.a
1832 TableScan: test projection=[a, b, c]
1833 "
1834 )
1835 }
1836
1837 #[test]
1838 fn join_schema_trim_full_join_column_projection() -> Result<()> {
1839 let table_scan = test_table_scan()?;
1840
1841 let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1842 let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1843
1844 let plan = LogicalPlanBuilder::from(table_scan)
1845 .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1846 .project(vec![col("a"), col("b"), col("c1")])?
1847 .build()?;
1848
1849 let optimized_plan = optimize(plan)?;
1850
1851 assert_snapshot!(
1853 optimized_plan.clone(),
1854 @r"
1855 Left Join: test.a = test2.c1
1856 TableScan: test projection=[a, b]
1857 TableScan: test2 projection=[c1]
1858 "
1859 );
1860
1861 let optimized_join = optimized_plan;
1863 assert_eq!(
1864 **optimized_join.schema(),
1865 DFSchema::new_with_metadata(
1866 vec![
1867 (
1868 Some("test".into()),
1869 Arc::new(Field::new("a", DataType::UInt32, false))
1870 ),
1871 (
1872 Some("test".into()),
1873 Arc::new(Field::new("b", DataType::UInt32, false))
1874 ),
1875 (
1876 Some("test2".into()),
1877 Arc::new(Field::new("c1", DataType::UInt32, true))
1878 ),
1879 ],
1880 HashMap::new()
1881 )?,
1882 );
1883
1884 Ok(())
1885 }
1886
1887 #[test]
1888 fn join_schema_trim_partial_join_column_projection() -> Result<()> {
1889 let table_scan = test_table_scan()?;
1892
1893 let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1894 let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1895
1896 let plan = LogicalPlanBuilder::from(table_scan)
1897 .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1898 .project(vec![col("a"), col("b")])?
1901 .build()?;
1902
1903 let optimized_plan = optimize(plan)?;
1904
1905 assert_snapshot!(
1907 optimized_plan.clone(),
1908 @r"
1909 Projection: test.a, test.b
1910 Left Join: test.a = test2.c1
1911 TableScan: test projection=[a, b]
1912 TableScan: test2 projection=[c1]
1913 "
1914 );
1915
1916 let optimized_join = optimized_plan.inputs()[0];
1918 assert_eq!(
1919 **optimized_join.schema(),
1920 DFSchema::new_with_metadata(
1921 vec![
1922 (
1923 Some("test".into()),
1924 Arc::new(Field::new("a", DataType::UInt32, false))
1925 ),
1926 (
1927 Some("test".into()),
1928 Arc::new(Field::new("b", DataType::UInt32, false))
1929 ),
1930 (
1931 Some("test2".into()),
1932 Arc::new(Field::new("c1", DataType::UInt32, true))
1933 ),
1934 ],
1935 HashMap::new()
1936 )?,
1937 );
1938
1939 Ok(())
1940 }
1941
1942 #[test]
1943 fn join_schema_trim_using_join() -> Result<()> {
1944 let table_scan = test_table_scan()?;
1947
1948 let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1949 let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1950
1951 let plan = LogicalPlanBuilder::from(table_scan)
1952 .join_using(table2_scan, JoinType::Left, vec!["a".into()])?
1953 .project(vec![col("a"), col("b")])?
1954 .build()?;
1955
1956 let optimized_plan = optimize(plan)?;
1957
1958 assert_snapshot!(
1960 optimized_plan.clone(),
1961 @r"
1962 Projection: test.a, test.b
1963 Left Join: Using test.a = test2.a
1964 TableScan: test projection=[a, b]
1965 TableScan: test2 projection=[a]
1966 "
1967 );
1968
1969 let optimized_join = optimized_plan.inputs()[0];
1971 assert_eq!(
1972 **optimized_join.schema(),
1973 DFSchema::new_with_metadata(
1974 vec![
1975 (
1976 Some("test".into()),
1977 Arc::new(Field::new("a", DataType::UInt32, false))
1978 ),
1979 (
1980 Some("test".into()),
1981 Arc::new(Field::new("b", DataType::UInt32, false))
1982 ),
1983 (
1984 Some("test2".into()),
1985 Arc::new(Field::new("a", DataType::UInt32, true))
1986 ),
1987 ],
1988 HashMap::new()
1989 )?,
1990 );
1991
1992 Ok(())
1993 }
1994
1995 #[test]
1996 fn cast() -> Result<()> {
1997 let table_scan = test_table_scan()?;
1998
1999 let plan = LogicalPlanBuilder::from(table_scan)
2000 .project(vec![Expr::Cast(Cast::new(
2001 Box::new(col("c")),
2002 DataType::Float64,
2003 ))])?
2004 .build()?;
2005
2006 assert_optimized_plan_equal!(
2007 plan,
2008 @r"
2009 Projection: CAST(test.c AS Float64)
2010 TableScan: test projection=[c]
2011 "
2012 )
2013 }
2014
2015 #[test]
2016 fn table_scan_projected_schema() -> Result<()> {
2017 let table_scan = test_table_scan()?;
2018 let plan = LogicalPlanBuilder::from(test_table_scan()?)
2019 .project(vec![col("a"), col("b")])?
2020 .build()?;
2021
2022 assert_eq!(3, table_scan.schema().fields().len());
2023 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2024 assert_fields_eq(&plan, vec!["a", "b"]);
2025
2026 assert_optimized_plan_equal!(
2027 plan,
2028 @"TableScan: test projection=[a, b]"
2029 )
2030 }
2031
2032 #[test]
2033 fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
2034 let table_scan = test_table_scan()?;
2035 let input_schema = table_scan.schema();
2036 assert_eq!(3, input_schema.fields().len());
2037 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2038
2039 let expr = vec![col("test.a"), col("test.b")];
2043 let plan =
2044 LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?);
2045
2046 assert_fields_eq(&plan, vec!["a", "b"]);
2047
2048 assert_optimized_plan_equal!(
2049 plan,
2050 @"TableScan: test projection=[a, b]"
2051 )
2052 }
2053
2054 #[test]
2055 fn table_limit() -> Result<()> {
2056 let table_scan = test_table_scan()?;
2057 assert_eq!(3, table_scan.schema().fields().len());
2058 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2059
2060 let plan = LogicalPlanBuilder::from(table_scan)
2061 .project(vec![col("c"), col("a")])?
2062 .limit(0, Some(5))?
2063 .build()?;
2064
2065 assert_fields_eq(&plan, vec!["c", "a"]);
2066
2067 assert_optimized_plan_equal!(
2068 plan,
2069 @r"
2070 Limit: skip=0, fetch=5
2071 Projection: test.c, test.a
2072 TableScan: test projection=[a, c]
2073 "
2074 )
2075 }
2076
2077 #[test]
2078 fn table_scan_without_projection() -> Result<()> {
2079 let table_scan = test_table_scan()?;
2080 let plan = LogicalPlanBuilder::from(table_scan).build()?;
2081 assert_optimized_plan_equal!(
2083 plan,
2084 @"TableScan: test projection=[a, b, c]"
2085 )
2086 }
2087
2088 #[test]
2089 fn table_scan_with_literal_projection() -> Result<()> {
2090 let table_scan = test_table_scan()?;
2091 let plan = LogicalPlanBuilder::from(table_scan)
2092 .project(vec![lit(1_i64), lit(2_i64)])?
2093 .build()?;
2094 assert_optimized_plan_equal!(
2095 plan,
2096 @r"
2097 Projection: Int64(1), Int64(2)
2098 TableScan: test projection=[]
2099 "
2100 )
2101 }
2102
2103 #[test]
2105 fn table_unused_column() -> Result<()> {
2106 let table_scan = test_table_scan()?;
2107 assert_eq!(3, table_scan.schema().fields().len());
2108 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2109
2110 let plan = LogicalPlanBuilder::from(table_scan)
2112 .project(vec![col("c"), col("a"), col("b")])?
2113 .filter(col("c").gt(lit(1)))?
2114 .aggregate(vec![col("c")], vec![max(col("a"))])?
2115 .build()?;
2116
2117 assert_fields_eq(&plan, vec!["c", "max(test.a)"]);
2118
2119 let plan = optimize(plan).expect("failed to optimize plan");
2120 assert_optimized_plan_equal!(
2121 plan,
2122 @r"
2123 Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
2124 Filter: test.c > Int32(1)
2125 Projection: test.c, test.a
2126 TableScan: test projection=[a, c]
2127 "
2128 )
2129 }
2130
2131 #[test]
2133 fn table_unused_projection() -> Result<()> {
2134 let table_scan = test_table_scan()?;
2135 assert_eq!(3, table_scan.schema().fields().len());
2136 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2137
2138 let plan = LogicalPlanBuilder::from(table_scan)
2140 .project(vec![col("b")])?
2141 .project(vec![lit(1).alias("a")])?
2142 .build()?;
2143
2144 assert_fields_eq(&plan, vec!["a"]);
2145
2146 assert_optimized_plan_equal!(
2147 plan,
2148 @r"
2149 Projection: Int32(1) AS a
2150 TableScan: test projection=[]
2151 "
2152 )
2153 }
2154
2155 #[test]
2156 fn table_full_filter_pushdown() -> Result<()> {
2157 let schema = Schema::new(test_table_scan_fields());
2158
2159 let table_scan = table_scan_with_filters(
2160 Some("test"),
2161 &schema,
2162 None,
2163 vec![col("b").eq(lit(1))],
2164 )?
2165 .build()?;
2166 assert_eq!(3, table_scan.schema().fields().len());
2167 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2168
2169 let plan = LogicalPlanBuilder::from(table_scan)
2171 .project(vec![col("b")])?
2172 .project(vec![lit(1).alias("a")])?
2173 .build()?;
2174
2175 assert_fields_eq(&plan, vec!["a"]);
2176
2177 assert_optimized_plan_equal!(
2178 plan,
2179 @r"
2180 Projection: Int32(1) AS a
2181 TableScan: test projection=[], full_filters=[b = Int32(1)]
2182 "
2183 )
2184 }
2185
2186 #[test]
2188 fn test_double_optimization() -> Result<()> {
2189 let table_scan = test_table_scan()?;
2190
2191 let plan = LogicalPlanBuilder::from(table_scan)
2192 .project(vec![col("b")])?
2193 .project(vec![lit(1).alias("a")])?
2194 .build()?;
2195
2196 let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
2197 let optimized_plan2 =
2198 optimize(optimized_plan1.clone()).expect("failed to optimize plan");
2199
2200 let formatted_plan1 = format!("{optimized_plan1:?}");
2201 let formatted_plan2 = format!("{optimized_plan2:?}");
2202 assert_eq!(formatted_plan1, formatted_plan2);
2203 Ok(())
2204 }
2205
2206 #[test]
2208 fn table_unused_aggregate() -> Result<()> {
2209 let table_scan = test_table_scan()?;
2210 assert_eq!(3, table_scan.schema().fields().len());
2211 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2212
2213 let plan = LogicalPlanBuilder::from(table_scan)
2215 .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
2216 .filter(col("c").gt(lit(1)))?
2217 .project(vec![col("c"), col("a"), col("max(test.b)")])?
2218 .build()?;
2219
2220 assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);
2221
2222 assert_optimized_plan_equal!(
2223 plan,
2224 @r"
2225 Projection: test.c, test.a, max(test.b)
2226 Filter: test.c > Int32(1)
2227 Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
2228 TableScan: test projection=[a, b, c]
2229 "
2230 )
2231 }
2232
2233 #[test]
2234 fn aggregate_filter_pushdown() -> Result<()> {
2235 let table_scan = test_table_scan()?;
2236 let aggr_with_filter = count_udaf()
2237 .call(vec![col("b")])
2238 .filter(col("c").gt(lit(42)))
2239 .build()?;
2240 let plan = LogicalPlanBuilder::from(table_scan)
2241 .aggregate(
2242 vec![col("a")],
2243 vec![count(col("b")), aggr_with_filter.alias("count2")],
2244 )?
2245 .build()?;
2246
2247 assert_optimized_plan_equal!(
2248 plan,
2249 @r"
2250 Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
2251 TableScan: test projection=[a, b, c]
2252 "
2253 )
2254 }
2255
2256 #[test]
2257 fn pushdown_through_distinct() -> Result<()> {
2258 let table_scan = test_table_scan()?;
2259
2260 let plan = LogicalPlanBuilder::from(table_scan)
2261 .project(vec![col("a"), col("b")])?
2262 .distinct()?
2263 .project(vec![col("a")])?
2264 .build()?;
2265
2266 assert_optimized_plan_equal!(
2267 plan,
2268 @r"
2269 Projection: test.a
2270 Distinct:
2271 TableScan: test projection=[a, b]
2272 "
2273 )
2274 }
2275
2276 #[test]
2277 fn test_window() -> Result<()> {
2278 let table_scan = test_table_scan()?;
2279
2280 let max1 = Expr::from(expr::WindowFunction::new(
2281 WindowFunctionDefinition::AggregateUDF(max_udaf()),
2282 vec![col("test.a")],
2283 ))
2284 .partition_by(vec![col("test.b")])
2285 .build()
2286 .unwrap();
2287
2288 let max2 = Expr::from(expr::WindowFunction::new(
2289 WindowFunctionDefinition::AggregateUDF(max_udaf()),
2290 vec![col("test.b")],
2291 ));
2292 let col1 = col(max1.schema_name().to_string());
2293 let col2 = col(max2.schema_name().to_string());
2294
2295 let plan = LogicalPlanBuilder::from(table_scan)
2296 .window(vec![max1])?
2297 .window(vec![max2])?
2298 .project(vec![col1, col2])?
2299 .build()?;
2300
2301 assert_optimized_plan_equal!(
2302 plan,
2303 @r"
2304 Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2305 WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2306 Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2307 WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2308 TableScan: test projection=[a, b]
2309 "
2310 )
2311 }
2312
2313 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2314
2315 fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2316 let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2317 let optimized_plan =
2318 optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2319 Ok(optimized_plan)
2320 }
2321}