1mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28 Column, DFSchema, HashMap, JoinType, Result, assert_eq_or_internal_err,
29 get_required_group_by_exprs_indices, internal_datafusion_err, internal_err,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33 Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window,
34 logical_plan::LogicalPlan,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
/// Optimizer rule that prunes columns no ancestor of a [`LogicalPlan`] node
/// needs. Among other things it:
/// - narrows table scans to the columns that are actually required,
/// - removes unused projection, aggregate and window expressions,
/// - merges consecutive projections when that does not duplicate work, and
/// - inserts projections below operators that benefit from narrower inputs.
#[derive(Debug, Default)]
pub struct OptimizeProjections {}

impl OptimizeProjections {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self::default()
    }
}
85
86impl OptimizerRule for OptimizeProjections {
87 fn name(&self) -> &str {
88 "optimize_projections"
89 }
90
91 fn apply_order(&self) -> Option<ApplyOrder> {
92 None
93 }
94
95 fn supports_rewrite(&self) -> bool {
96 true
97 }
98
99 fn rewrite(
100 &self,
101 plan: LogicalPlan,
102 config: &dyn OptimizerConfig,
103 ) -> Result<Transformed<LogicalPlan>> {
104 let indices = RequiredIndices::new_for_all_exprs(&plan);
106 optimize_projections(plan, config, indices)
107 }
108}
109
110#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
130fn optimize_projections(
131 plan: LogicalPlan,
132 config: &dyn OptimizerConfig,
133 indices: RequiredIndices,
134) -> Result<Transformed<LogicalPlan>> {
135 match plan {
138 LogicalPlan::Projection(proj) => {
139 return merge_consecutive_projections(proj)?.transform_data(|proj| {
140 rewrite_projection_given_requirements(proj, config, &indices)
141 });
142 }
143 LogicalPlan::Aggregate(aggregate) => {
144 let n_group_exprs = aggregate.group_expr_len()?;
146 let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);
149
150 let group_by_expr_existing = aggregate
152 .group_expr
153 .iter()
154 .map(|group_by_expr| group_by_expr.schema_name().to_string())
155 .collect::<Vec<_>>();
156
157 let new_group_bys = if let Some(simplest_groupby_indices) =
158 get_required_group_by_exprs_indices(
159 aggregate.input.schema(),
160 &group_by_expr_existing,
161 ) {
162 group_by_reqs
166 .append(&simplest_groupby_indices)
167 .get_at_indices(&aggregate.group_expr)
168 } else {
169 aggregate.group_expr
170 };
171
172 let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);
175
176 if new_group_bys.is_empty() && new_aggr_expr.is_empty() {
177 return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
179 EmptyRelation {
180 produce_one_row: true,
181 schema: Arc::new(DFSchema::empty()),
182 },
183 )));
184 }
185
186 let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
187 let schema = aggregate.input.schema();
188 let necessary_indices =
189 RequiredIndices::new().with_exprs(schema, all_exprs_iter);
190 let necessary_exprs = necessary_indices.get_required_exprs(schema);
191
192 return optimize_projections(
193 Arc::unwrap_or_clone(aggregate.input),
194 config,
195 necessary_indices,
196 )?
197 .transform_data(|aggregate_input| {
198 add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
203 })?
204 .map_data(|aggregate_input| {
205 Aggregate::try_new(
208 Arc::new(aggregate_input),
209 new_group_bys,
210 new_aggr_expr,
211 )
212 .map(LogicalPlan::Aggregate)
213 });
214 }
215 LogicalPlan::Window(window) => {
216 let input_schema = Arc::clone(window.input.schema());
217 let n_input_fields = input_schema.fields().len();
219 let (child_reqs, window_reqs) = indices.split_off(n_input_fields);
222
223 let new_window_expr = window_reqs.get_at_indices(&window.window_expr);
226
227 let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);
230
231 return optimize_projections(
232 Arc::unwrap_or_clone(window.input),
233 config,
234 required_indices.clone(),
235 )?
236 .transform_data(|window_child| {
237 if new_window_expr.is_empty() {
238 Ok(Transformed::no(window_child))
240 } else {
241 let required_exprs =
245 required_indices.get_required_exprs(&input_schema);
246 let window_child =
247 add_projection_on_top_if_helpful(window_child, required_exprs)?
248 .data;
249 Window::try_new(new_window_expr, Arc::new(window_child))
250 .map(LogicalPlan::Window)
251 .map(Transformed::yes)
252 }
253 });
254 }
255 LogicalPlan::TableScan(table_scan) => {
256 let TableScan {
257 table_name,
258 source,
259 projection,
260 filters,
261 fetch,
262 projected_schema: _,
263 } = table_scan;
264
265 let projection = match &projection {
268 Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
269 None => indices.into_inner(),
270 };
271 let new_scan =
272 TableScan::try_new(table_name, source, Some(projection), filters, fetch)?;
273
274 return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan)));
275 }
276 _ => {}
278 };
279
280 let mut child_required_indices: Vec<RequiredIndices> = match &plan {
283 LogicalPlan::Sort(_)
284 | LogicalPlan::Filter(_)
285 | LogicalPlan::Repartition(_)
286 | LogicalPlan::Union(_)
287 | LogicalPlan::SubqueryAlias(_)
288 | LogicalPlan::Distinct(Distinct::On(_)) => {
289 plan.inputs()
294 .into_iter()
295 .map(|input| {
296 indices
297 .clone()
298 .with_projection_beneficial()
299 .with_plan_exprs(&plan, input.schema())
300 })
301 .collect::<Result<_>>()?
302 }
303 LogicalPlan::Limit(_) => {
304 plan.inputs()
309 .into_iter()
310 .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
311 .collect::<Result<_>>()?
312 }
313 LogicalPlan::Copy(_)
314 | LogicalPlan::Ddl(_)
315 | LogicalPlan::Dml(_)
316 | LogicalPlan::Explain(_)
317 | LogicalPlan::Analyze(_)
318 | LogicalPlan::Subquery(_)
319 | LogicalPlan::Statement(_)
320 | LogicalPlan::Distinct(Distinct::All(_)) => {
321 plan.inputs()
327 .into_iter()
328 .map(RequiredIndices::new_for_all_exprs)
329 .collect()
330 }
331 LogicalPlan::Extension(extension) => {
332 let Some(necessary_children_indices) =
333 extension.node.necessary_children_exprs(indices.indices())
334 else {
335 return Ok(Transformed::no(plan));
337 };
338 let children = extension.node.inputs();
339 assert_eq_or_internal_err!(
340 children.len(),
341 necessary_children_indices.len(),
342 "Inconsistent length between children and necessary children indices. \
343 Make sure `.necessary_children_exprs` implementation of the \
344 `UserDefinedLogicalNode` is consistent with actual children length \
345 for the node."
346 );
347 children
348 .into_iter()
349 .zip(necessary_children_indices)
350 .map(|(child, necessary_indices)| {
351 RequiredIndices::new_from_indices(necessary_indices)
352 .with_plan_exprs(&plan, child.schema())
353 })
354 .collect::<Result<Vec<_>>>()?
355 }
356 LogicalPlan::EmptyRelation(_)
357 | LogicalPlan::Values(_)
358 | LogicalPlan::DescribeTable(_) => {
359 return Ok(Transformed::no(plan));
361 }
362 LogicalPlan::RecursiveQuery(recursive) => {
363 if plan_contains_other_subqueries(
367 recursive.static_term.as_ref(),
368 &recursive.name,
369 ) || plan_contains_other_subqueries(
370 recursive.recursive_term.as_ref(),
371 &recursive.name,
372 ) {
373 return Ok(Transformed::no(plan));
374 }
375
376 plan.inputs()
377 .into_iter()
378 .map(|input| {
379 indices
380 .clone()
381 .with_projection_beneficial()
382 .with_plan_exprs(&plan, input.schema())
383 })
384 .collect::<Result<Vec<_>>>()?
385 }
386 LogicalPlan::Join(join) => {
387 let left_len = join.left.schema().fields().len();
388 let (left_req_indices, right_req_indices) =
389 split_join_requirements(left_len, indices, &join.join_type);
390 let left_indices =
391 left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
392 let right_indices =
393 right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
394 vec![
397 left_indices.with_projection_beneficial(),
398 right_indices.with_projection_beneficial(),
399 ]
400 }
401 LogicalPlan::Projection(_)
403 | LogicalPlan::Aggregate(_)
404 | LogicalPlan::Window(_)
405 | LogicalPlan::TableScan(_) => {
406 return internal_err!(
407 "OptimizeProjection: should have handled in the match statement above"
408 );
409 }
410 LogicalPlan::Unnest(Unnest {
411 input,
412 dependency_indices,
413 ..
414 }) => {
415 let required_indices =
417 RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;
418
419 let mut additional_necessary_child_indices = Vec::new();
421 indices.indices().iter().for_each(|idx| {
422 if let Some(index) = dependency_indices.get(*idx) {
423 additional_necessary_child_indices.push(*index);
424 }
425 });
426 vec![required_indices.append(&additional_necessary_child_indices)]
427 }
428 };
429
430 child_required_indices.reverse();
433 assert_eq_or_internal_err!(
434 child_required_indices.len(),
435 plan.inputs().len(),
436 "OptimizeProjection: child_required_indices length mismatch with plan inputs"
437 );
438
439 let transformed_plan = plan.map_children(|child| {
441 let required_indices = child_required_indices.pop().ok_or_else(|| {
442 internal_datafusion_err!(
443 "Unexpected number of required_indices in OptimizeProjections rule"
444 )
445 })?;
446
447 let projection_beneficial = required_indices.projection_beneficial();
448 let project_exprs = required_indices.get_required_exprs(child.schema());
449
450 optimize_projections(child, config, required_indices)?.transform_data(
451 |new_input| {
452 if projection_beneficial {
453 add_projection_on_top_if_helpful(new_input, project_exprs)
454 } else {
455 Ok(Transformed::no(new_input))
456 }
457 },
458 )
459 })?;
460
461 if transformed_plan.transformed {
463 transformed_plan.map_data(|plan| plan.recompute_schema())
464 } else {
465 Ok(transformed_plan)
466 }
467}
468
469fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
502 let Projection {
503 expr,
504 input,
505 schema,
506 ..
507 } = proj;
508 let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
509 return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
510 };
511
512 if prev_projection.expr == expr {
515 return Projection::try_new_with_schema(
516 expr,
517 Arc::clone(&prev_projection.input),
518 schema,
519 )
520 .map(Transformed::yes);
521 }
522
523 let mut column_referral_map = HashMap::<&Column, usize>::new();
525 expr.iter()
526 .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
527
528 if column_referral_map.into_iter().any(|(col, usage)| {
532 usage > 1
533 && !prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()]
534 .placement()
535 .should_push_to_leaves()
536 }) {
537 return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
539 }
540
541 let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
542 unreachable!();
544 };
545
546 let name_preserver = NamePreserver::new_for_projection();
549 let mut original_names = vec![];
550 let new_exprs = expr.map_elements(|expr| {
551 original_names.push(name_preserver.save(&expr));
552
553 match expr {
555 Expr::Alias(Alias {
556 expr,
557 relation,
558 name,
559 metadata,
560 }) => rewrite_expr(*expr, &prev_projection).map(|result| {
561 result.update_data(|expr| {
562 if metadata.is_none() && expr.schema_name().to_string() == name {
569 expr
570 } else {
571 Expr::Alias(
572 Alias::new(expr, relation, name).with_metadata(metadata),
573 )
574 }
575 })
576 }),
577 e => rewrite_expr(e, &prev_projection),
578 }
579 })?;
580
581 if new_exprs.transformed {
584 let new_exprs = new_exprs
586 .data
587 .into_iter()
588 .zip(original_names)
589 .map(|(expr, original_name)| original_name.restore(expr))
590 .collect::<Vec<_>>();
591 Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
592 } else {
593 let input = Arc::new(LogicalPlan::Projection(prev_projection));
595 Projection::try_new_with_schema(new_exprs.data, input, schema)
596 .map(Transformed::no)
597 }
598}
599
600fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
646 expr.transform_up(|expr| {
647 match expr {
648 Expr::Alias(alias) => {
650 match alias
651 .metadata
652 .as_ref()
653 .map(|h| h.is_empty())
654 .unwrap_or(true)
655 {
656 true => Ok(Transformed::yes(*alias.expr)),
657 false => Ok(Transformed::no(Expr::Alias(alias))),
658 }
659 }
660 Expr::Column(col) => {
661 let idx = input.schema.index_of_column(&col)?;
663 let input_expr = input.expr[idx].clone().unalias_nested().data;
671 Ok(Transformed::yes(input_expr))
672 }
673 _ => Ok(Transformed::no(expr)),
675 }
676 })
677}
678
679fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
688 expr.apply(|expr| {
690 match expr {
691 Expr::OuterReferenceColumn(_, col) => {
692 columns.insert(col);
693 }
694 Expr::ScalarSubquery(subquery) => {
695 outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
696 }
697 Expr::Exists(exists) => {
698 outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
699 }
700 Expr::InSubquery(insubquery) => {
701 outer_columns_helper_multi(
702 &insubquery.subquery.outer_ref_columns,
703 columns,
704 );
705 }
706 _ => {}
707 };
708 Ok(TreeNodeRecursion::Continue)
709 })
710 .unwrap();
712}
713
714fn outer_columns_helper_multi<'a, 'b>(
723 exprs: impl IntoIterator<Item = &'a Expr>,
724 columns: &'b mut HashSet<&'a Column>,
725) {
726 exprs.into_iter().for_each(|e| outer_columns(e, columns));
727}
728
729fn split_join_requirements(
759 left_len: usize,
760 indices: RequiredIndices,
761 join_type: &JoinType,
762) -> (RequiredIndices, RequiredIndices) {
763 match join_type {
764 JoinType::Inner
766 | JoinType::Left
767 | JoinType::Right
768 | JoinType::Full
769 | JoinType::LeftMark
770 | JoinType::RightMark => {
771 indices.split_off(left_len)
774 }
775 JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
777 JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
780 }
781}
782
783fn add_projection_on_top_if_helpful(
801 plan: LogicalPlan,
802 project_exprs: Vec<Expr>,
803) -> Result<Transformed<LogicalPlan>> {
804 if project_exprs.len() >= plan.schema().fields().len() {
806 Ok(Transformed::no(plan))
807 } else {
808 Projection::try_new(project_exprs, Arc::new(plan))
809 .map(LogicalPlan::Projection)
810 .map(Transformed::yes)
811 }
812}
813
814fn rewrite_projection_given_requirements(
832 proj: Projection,
833 config: &dyn OptimizerConfig,
834 indices: &RequiredIndices,
835) -> Result<Transformed<LogicalPlan>> {
836 let Projection { expr, input, .. } = proj;
837
838 let exprs_used = indices.get_at_indices(&expr);
839
840 let required_indices =
841 RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
842
843 optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
846 .transform_data(|input| {
847 if is_projection_unnecessary(&input, &exprs_used)? {
848 Ok(Transformed::yes(input))
849 } else {
850 Projection::try_new(exprs_used, Arc::new(input))
851 .map(LogicalPlan::Projection)
852 .map(Transformed::yes)
853 }
854 })
855}
856
857pub fn is_projection_unnecessary(
861 input: &LogicalPlan,
862 proj_exprs: &[Expr],
863) -> Result<bool> {
864 if proj_exprs.len() != input.schema().fields().len() {
866 return Ok(false);
867 }
868 Ok(input.schema().iter().zip(proj_exprs.iter()).all(
869 |((field_relation, field_name), expr)| {
870 if let Expr::Column(col) = expr {
872 col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
873 } else {
874 false
875 }
876 },
877 ))
878}
879
880fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
886 if let LogicalPlan::SubqueryAlias(alias) = plan
887 && alias.alias.table() != cte_name
888 && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
889 {
890 return true;
891 }
892
893 let mut found = false;
894 plan.apply_expressions(|expr| {
895 if expr_contains_subquery(expr) {
896 found = true;
897 Ok(TreeNodeRecursion::Stop)
898 } else {
899 Ok(TreeNodeRecursion::Continue)
900 }
901 })
902 .expect("expression traversal never fails");
903 if found {
904 return true;
905 }
906
907 plan.inputs()
908 .into_iter()
909 .any(|child| plan_contains_other_subqueries(child, cte_name))
910}
911
912fn expr_contains_subquery(expr: &Expr) -> bool {
913 expr.exists(|e| match e {
914 Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
915 _ => Ok(false),
916 })
917 .unwrap()
919}
920
921fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool {
922 match plan {
923 LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name,
924 LogicalPlan::SubqueryAlias(alias) => {
925 subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
926 }
927 _ => {
928 let inputs = plan.inputs();
929 if inputs.len() == 1 {
930 subquery_alias_targets_recursive_cte(inputs[0], cte_name)
931 } else {
932 false
933 }
934 }
935 }
936}
937
938#[cfg(test)]
939mod tests {
940 use std::cmp::Ordering;
941 use std::collections::HashMap;
942 use std::fmt::Formatter;
943 use std::ops::Add;
944 use std::sync::Arc;
945 use std::vec;
946
947 use crate::optimize_projections::OptimizeProjections;
948 use crate::optimizer::Optimizer;
949 use crate::test::{
950 assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
951 test_table_scan_with_name,
952 };
953 use crate::{OptimizerContext, OptimizerRule};
954 use arrow::datatypes::{DataType, Field, Schema};
955 use datafusion_common::{
956 Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
957 };
958 use datafusion_expr::ExprFunctionExt;
959 use datafusion_expr::{
960 BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection,
961 UserDefinedLogicalNodeCore, WindowFunctionDefinition, binary_expr,
962 build_join_schema,
963 builder::table_scan_with_filters,
964 col,
965 expr::{self, Cast},
966 lit,
967 logical_plan::{builder::LogicalPlanBuilder, table_scan},
968 not, try_cast, when,
969 };
970 use insta::assert_snapshot;
971
972 use crate::assert_optimized_plan_eq_snapshot;
973 use datafusion_functions_aggregate::count::count_udaf;
974 use datafusion_functions_aggregate::expr_fn::{count, max, min};
975 use datafusion_functions_aggregate::min_max::max_udaf;
976
977 macro_rules! assert_optimized_plan_equal {
978 (
979 $plan:expr,
980 @ $expected:literal $(,)?
981 ) => {{
982 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
983 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
984 assert_optimized_plan_eq_snapshot!(
985 optimizer_ctx,
986 rules,
987 $plan,
988 @ $expected,
989 )
990 }};
991 }
992
993 #[derive(Debug, Hash, PartialEq, Eq)]
994 struct NoOpUserDefined {
995 exprs: Vec<Expr>,
996 schema: DFSchemaRef,
997 input: Arc<LogicalPlan>,
998 }
999
1000 impl NoOpUserDefined {
1001 fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
1002 Self {
1003 exprs: vec![],
1004 schema,
1005 input,
1006 }
1007 }
1008
1009 fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
1010 self.exprs = exprs;
1011 self
1012 }
1013 }
1014
1015 impl PartialOrd for NoOpUserDefined {
1017 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1018 match self.exprs.partial_cmp(&other.exprs) {
1019 Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
1020 cmp => cmp,
1021 }
1022 .filter(|cmp| *cmp != Ordering::Equal || self == other)
1024 }
1025 }
1026
1027 impl UserDefinedLogicalNodeCore for NoOpUserDefined {
1028 fn name(&self) -> &str {
1029 "NoOpUserDefined"
1030 }
1031
1032 fn inputs(&self) -> Vec<&LogicalPlan> {
1033 vec![&self.input]
1034 }
1035
1036 fn schema(&self) -> &DFSchemaRef {
1037 &self.schema
1038 }
1039
1040 fn expressions(&self) -> Vec<Expr> {
1041 self.exprs.clone()
1042 }
1043
1044 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1045 write!(f, "NoOpUserDefined")
1046 }
1047
1048 fn with_exprs_and_inputs(
1049 &self,
1050 exprs: Vec<Expr>,
1051 mut inputs: Vec<LogicalPlan>,
1052 ) -> Result<Self> {
1053 Ok(Self {
1054 exprs,
1055 input: Arc::new(inputs.swap_remove(0)),
1056 schema: Arc::clone(&self.schema),
1057 })
1058 }
1059
1060 fn necessary_children_exprs(
1061 &self,
1062 output_columns: &[usize],
1063 ) -> Option<Vec<Vec<usize>>> {
1064 Some(vec![output_columns.to_vec()])
1066 }
1067
1068 fn supports_limit_pushdown(&self) -> bool {
1069 false }
1071 }
1072
1073 #[derive(Debug, Hash, PartialEq, Eq)]
1074 struct UserDefinedCrossJoin {
1075 exprs: Vec<Expr>,
1076 schema: DFSchemaRef,
1077 left_child: Arc<LogicalPlan>,
1078 right_child: Arc<LogicalPlan>,
1079 }
1080
1081 impl UserDefinedCrossJoin {
1082 fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
1083 let left_schema = left_child.schema();
1084 let right_schema = right_child.schema();
1085 let schema = Arc::new(
1086 build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
1087 );
1088 Self {
1089 exprs: vec![],
1090 schema,
1091 left_child,
1092 right_child,
1093 }
1094 }
1095 }
1096
1097 impl PartialOrd for UserDefinedCrossJoin {
1099 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1100 match self.exprs.partial_cmp(&other.exprs) {
1101 Some(Ordering::Equal) => {
1102 match self.left_child.partial_cmp(&other.left_child) {
1103 Some(Ordering::Equal) => {
1104 self.right_child.partial_cmp(&other.right_child)
1105 }
1106 cmp => cmp,
1107 }
1108 }
1109 cmp => cmp,
1110 }
1111 .filter(|cmp| *cmp != Ordering::Equal || self == other)
1113 }
1114 }
1115
1116 impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
1117 fn name(&self) -> &str {
1118 "UserDefinedCrossJoin"
1119 }
1120
1121 fn inputs(&self) -> Vec<&LogicalPlan> {
1122 vec![&self.left_child, &self.right_child]
1123 }
1124
1125 fn schema(&self) -> &DFSchemaRef {
1126 &self.schema
1127 }
1128
1129 fn expressions(&self) -> Vec<Expr> {
1130 self.exprs.clone()
1131 }
1132
1133 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1134 write!(f, "UserDefinedCrossJoin")
1135 }
1136
1137 fn with_exprs_and_inputs(
1138 &self,
1139 exprs: Vec<Expr>,
1140 mut inputs: Vec<LogicalPlan>,
1141 ) -> Result<Self> {
1142 assert_eq!(inputs.len(), 2);
1143 Ok(Self {
1144 exprs,
1145 left_child: Arc::new(inputs.remove(0)),
1146 right_child: Arc::new(inputs.remove(0)),
1147 schema: Arc::clone(&self.schema),
1148 })
1149 }
1150
1151 fn necessary_children_exprs(
1152 &self,
1153 output_columns: &[usize],
1154 ) -> Option<Vec<Vec<usize>>> {
1155 let left_child_len = self.left_child.schema().fields().len();
1156 let mut left_reqs = vec![];
1157 let mut right_reqs = vec![];
1158 for &out_idx in output_columns {
1159 if out_idx < left_child_len {
1160 left_reqs.push(out_idx);
1161 } else {
1162 right_reqs.push(out_idx - left_child_len)
1165 }
1166 }
1167 Some(vec![left_reqs, right_reqs])
1168 }
1169
1170 fn supports_limit_pushdown(&self) -> bool {
1171 false }
1173 }
1174
1175 #[test]
1176 fn merge_two_projection() -> Result<()> {
1177 let table_scan = test_table_scan()?;
1178 let plan = LogicalPlanBuilder::from(table_scan)
1179 .project(vec![col("a")])?
1180 .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1181 .build()?;
1182
1183 assert_optimized_plan_equal!(
1184 plan,
1185 @r"
1186 Projection: Int32(1) + test.a
1187 TableScan: test projection=[a]
1188 "
1189 )
1190 }
1191
1192 #[test]
1193 fn merge_three_projection() -> Result<()> {
1194 let table_scan = test_table_scan()?;
1195 let plan = LogicalPlanBuilder::from(table_scan)
1196 .project(vec![col("a"), col("b")])?
1197 .project(vec![col("a")])?
1198 .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1199 .build()?;
1200
1201 assert_optimized_plan_equal!(
1202 plan,
1203 @r"
1204 Projection: Int32(1) + test.a
1205 TableScan: test projection=[a]
1206 "
1207 )
1208 }
1209
1210 #[test]
1211 fn merge_alias() -> Result<()> {
1212 let table_scan = test_table_scan()?;
1213 let plan = LogicalPlanBuilder::from(table_scan)
1214 .project(vec![col("a")])?
1215 .project(vec![col("a").alias("alias")])?
1216 .build()?;
1217
1218 assert_optimized_plan_equal!(
1219 plan,
1220 @r"
1221 Projection: test.a AS alias
1222 TableScan: test projection=[a]
1223 "
1224 )
1225 }
1226
1227 #[test]
1228 fn merge_nested_alias() -> Result<()> {
1229 let table_scan = test_table_scan()?;
1230 let plan = LogicalPlanBuilder::from(table_scan)
1231 .project(vec![col("a").alias("alias1").alias("alias2")])?
1232 .project(vec![col("alias2").alias("alias")])?
1233 .build()?;
1234
1235 assert_optimized_plan_equal!(
1236 plan,
1237 @r"
1238 Projection: test.a AS alias
1239 TableScan: test projection=[a]
1240 "
1241 )
1242 }
1243
1244 #[test]
1245 fn test_nested_count() -> Result<()> {
1246 let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
1247
1248 let groups: Vec<Expr> = vec![];
1249
1250 let plan = table_scan(TableReference::none(), &schema, None)
1251 .unwrap()
1252 .aggregate(groups.clone(), vec![count(lit(1))])
1253 .unwrap()
1254 .aggregate(groups, vec![count(lit(1))])
1255 .unwrap()
1256 .build()
1257 .unwrap();
1258
1259 assert_optimized_plan_equal!(
1260 plan,
1261 @r"
1262 Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1263 EmptyRelation: rows=1
1264 "
1265 )
1266 }
1267
1268 #[test]
1269 fn test_neg_push_down() -> Result<()> {
1270 let table_scan = test_table_scan()?;
1271 let plan = LogicalPlanBuilder::from(table_scan)
1272 .project(vec![-col("a")])?
1273 .build()?;
1274
1275 assert_optimized_plan_equal!(
1276 plan,
1277 @r"
1278 Projection: (- test.a)
1279 TableScan: test projection=[a]
1280 "
1281 )
1282 }
1283
1284 #[test]
1285 fn test_is_null() -> Result<()> {
1286 let table_scan = test_table_scan()?;
1287 let plan = LogicalPlanBuilder::from(table_scan)
1288 .project(vec![col("a").is_null()])?
1289 .build()?;
1290
1291 assert_optimized_plan_equal!(
1292 plan,
1293 @r"
1294 Projection: test.a IS NULL
1295 TableScan: test projection=[a]
1296 "
1297 )
1298 }
1299
1300 #[test]
1301 fn test_is_not_null() -> Result<()> {
1302 let table_scan = test_table_scan()?;
1303 let plan = LogicalPlanBuilder::from(table_scan)
1304 .project(vec![col("a").is_not_null()])?
1305 .build()?;
1306
1307 assert_optimized_plan_equal!(
1308 plan,
1309 @r"
1310 Projection: test.a IS NOT NULL
1311 TableScan: test projection=[a]
1312 "
1313 )
1314 }
1315
1316 #[test]
1317 fn test_is_true() -> Result<()> {
1318 let table_scan = test_table_scan()?;
1319 let plan = LogicalPlanBuilder::from(table_scan)
1320 .project(vec![col("a").is_true()])?
1321 .build()?;
1322
1323 assert_optimized_plan_equal!(
1324 plan,
1325 @r"
1326 Projection: test.a IS TRUE
1327 TableScan: test projection=[a]
1328 "
1329 )
1330 }
1331
1332 #[test]
1333 fn test_is_not_true() -> Result<()> {
1334 let table_scan = test_table_scan()?;
1335 let plan = LogicalPlanBuilder::from(table_scan)
1336 .project(vec![col("a").is_not_true()])?
1337 .build()?;
1338
1339 assert_optimized_plan_equal!(
1340 plan,
1341 @r"
1342 Projection: test.a IS NOT TRUE
1343 TableScan: test projection=[a]
1344 "
1345 )
1346 }
1347
1348 #[test]
1349 fn test_is_false() -> Result<()> {
1350 let table_scan = test_table_scan()?;
1351 let plan = LogicalPlanBuilder::from(table_scan)
1352 .project(vec![col("a").is_false()])?
1353 .build()?;
1354
1355 assert_optimized_plan_equal!(
1356 plan,
1357 @r"
1358 Projection: test.a IS FALSE
1359 TableScan: test projection=[a]
1360 "
1361 )
1362 }
1363
1364 #[test]
1365 fn test_is_not_false() -> Result<()> {
1366 let table_scan = test_table_scan()?;
1367 let plan = LogicalPlanBuilder::from(table_scan)
1368 .project(vec![col("a").is_not_false()])?
1369 .build()?;
1370
1371 assert_optimized_plan_equal!(
1372 plan,
1373 @r"
1374 Projection: test.a IS NOT FALSE
1375 TableScan: test projection=[a]
1376 "
1377 )
1378 }
1379
1380 #[test]
1381 fn test_is_unknown() -> Result<()> {
1382 let table_scan = test_table_scan()?;
1383 let plan = LogicalPlanBuilder::from(table_scan)
1384 .project(vec![col("a").is_unknown()])?
1385 .build()?;
1386
1387 assert_optimized_plan_equal!(
1388 plan,
1389 @r"
1390 Projection: test.a IS UNKNOWN
1391 TableScan: test projection=[a]
1392 "
1393 )
1394 }
1395
1396 #[test]
1397 fn test_is_not_unknown() -> Result<()> {
1398 let table_scan = test_table_scan()?;
1399 let plan = LogicalPlanBuilder::from(table_scan)
1400 .project(vec![col("a").is_not_unknown()])?
1401 .build()?;
1402
1403 assert_optimized_plan_equal!(
1404 plan,
1405 @r"
1406 Projection: test.a IS NOT UNKNOWN
1407 TableScan: test projection=[a]
1408 "
1409 )
1410 }
1411
1412 #[test]
1413 fn test_not() -> Result<()> {
1414 let table_scan = test_table_scan()?;
1415 let plan = LogicalPlanBuilder::from(table_scan)
1416 .project(vec![not(col("a"))])?
1417 .build()?;
1418
1419 assert_optimized_plan_equal!(
1420 plan,
1421 @r"
1422 Projection: NOT test.a
1423 TableScan: test projection=[a]
1424 "
1425 )
1426 }
1427
1428 #[test]
1429 fn test_try_cast() -> Result<()> {
1430 let table_scan = test_table_scan()?;
1431 let plan = LogicalPlanBuilder::from(table_scan)
1432 .project(vec![try_cast(col("a"), DataType::Float64)])?
1433 .build()?;
1434
1435 assert_optimized_plan_equal!(
1436 plan,
1437 @r"
1438 Projection: TRY_CAST(test.a AS Float64)
1439 TableScan: test projection=[a]
1440 "
1441 )
1442 }
1443
1444 #[test]
1445 fn test_similar_to() -> Result<()> {
1446 let table_scan = test_table_scan()?;
1447 let expr = Box::new(col("a"));
1448 let pattern = Box::new(lit("[0-9]"));
1449 let similar_to_expr =
1450 Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1451 let plan = LogicalPlanBuilder::from(table_scan)
1452 .project(vec![similar_to_expr])?
1453 .build()?;
1454
1455 assert_optimized_plan_equal!(
1456 plan,
1457 @r#"
1458 Projection: test.a SIMILAR TO Utf8("[0-9]")
1459 TableScan: test projection=[a]
1460 "#
1461 )
1462 }
1463
1464 #[test]
1465 fn test_between() -> Result<()> {
1466 let table_scan = test_table_scan()?;
1467 let plan = LogicalPlanBuilder::from(table_scan)
1468 .project(vec![col("a").between(lit(1), lit(3))])?
1469 .build()?;
1470
1471 assert_optimized_plan_equal!(
1472 plan,
1473 @r"
1474 Projection: test.a BETWEEN Int32(1) AND Int32(3)
1475 TableScan: test projection=[a]
1476 "
1477 )
1478 }
1479
1480 #[test]
1482 fn test_case_merged() -> Result<()> {
1483 let table_scan = test_table_scan()?;
1484 let plan = LogicalPlanBuilder::from(table_scan)
1485 .project(vec![col("a"), lit(0).alias("d")])?
1486 .project(vec![
1487 col("a"),
1488 when(col("a").eq(lit(1)), lit(10))
1489 .otherwise(col("d"))?
1490 .alias("d"),
1491 ])?
1492 .build()?;
1493
1494 assert_optimized_plan_equal!(
1495 plan,
1496 @r"
1497 Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
1498 TableScan: test projection=[a]
1499 "
1500 )
1501 }
1502
1503 #[test]
1506 fn test_derived_column() -> Result<()> {
1507 let table_scan = test_table_scan()?;
1508 let plan = LogicalPlanBuilder::from(table_scan)
1509 .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
1510 .project(vec![
1511 col("a"),
1512 when(col("a").eq(lit(1)), lit(10))
1513 .otherwise(col("d"))?
1514 .alias("d"),
1515 ])?
1516 .build()?;
1517
1518 assert_optimized_plan_equal!(
1519 plan,
1520 @r"
1521 Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
1522 Projection: test.a + Int32(1) AS a, Int32(0) AS d
1523 TableScan: test projection=[a]
1524 "
1525 )
1526 }
1527
1528 #[test]
1531 fn test_user_defined_logical_plan_node() -> Result<()> {
1532 let table_scan = test_table_scan()?;
1533 let custom_plan = LogicalPlan::Extension(Extension {
1534 node: Arc::new(NoOpUserDefined::new(
1535 Arc::clone(table_scan.schema()),
1536 Arc::new(table_scan.clone()),
1537 )),
1538 });
1539 let plan = LogicalPlanBuilder::from(custom_plan)
1540 .project(vec![col("a"), lit(0).alias("d")])?
1541 .build()?;
1542
1543 assert_optimized_plan_equal!(
1544 plan,
1545 @r"
1546 Projection: test.a, Int32(0) AS d
1547 NoOpUserDefined
1548 TableScan: test projection=[a]
1549 "
1550 )
1551 }
1552
1553 #[test]
1558 fn test_user_defined_logical_plan_node2() -> Result<()> {
1559 let table_scan = test_table_scan()?;
1560 let exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
1561 let custom_plan = LogicalPlan::Extension(Extension {
1562 node: Arc::new(
1563 NoOpUserDefined::new(
1564 Arc::clone(table_scan.schema()),
1565 Arc::new(table_scan.clone()),
1566 )
1567 .with_exprs(exprs),
1568 ),
1569 });
1570 let plan = LogicalPlanBuilder::from(custom_plan)
1571 .project(vec![col("a"), lit(0).alias("d")])?
1572 .build()?;
1573
1574 assert_optimized_plan_equal!(
1575 plan,
1576 @r"
1577 Projection: test.a, Int32(0) AS d
1578 NoOpUserDefined
1579 TableScan: test projection=[a, b]
1580 "
1581 )
1582 }
1583
1584 #[test]
1590 fn test_user_defined_logical_plan_node3() -> Result<()> {
1591 let table_scan = test_table_scan()?;
1592 let left_expr = Expr::Column(Column::from_qualified_name("b"));
1593 let right_expr = Expr::Column(Column::from_qualified_name("c"));
1594 let binary_expr = Expr::BinaryExpr(BinaryExpr::new(
1595 Box::new(left_expr),
1596 Operator::Plus,
1597 Box::new(right_expr),
1598 ));
1599 let exprs = vec![binary_expr];
1600 let custom_plan = LogicalPlan::Extension(Extension {
1601 node: Arc::new(
1602 NoOpUserDefined::new(
1603 Arc::clone(table_scan.schema()),
1604 Arc::new(table_scan.clone()),
1605 )
1606 .with_exprs(exprs),
1607 ),
1608 });
1609 let plan = LogicalPlanBuilder::from(custom_plan)
1610 .project(vec![col("a"), lit(0).alias("d")])?
1611 .build()?;
1612
1613 assert_optimized_plan_equal!(
1614 plan,
1615 @r"
1616 Projection: test.a, Int32(0) AS d
1617 NoOpUserDefined
1618 TableScan: test projection=[a, b, c]
1619 "
1620 )
1621 }
1622
1623 #[test]
1628 fn test_user_defined_logical_plan_node4() -> Result<()> {
1629 let left_table = test_table_scan_with_name("l")?;
1630 let right_table = test_table_scan_with_name("r")?;
1631 let custom_plan = LogicalPlan::Extension(Extension {
1632 node: Arc::new(UserDefinedCrossJoin::new(
1633 Arc::new(left_table),
1634 Arc::new(right_table),
1635 )),
1636 });
1637 let plan = LogicalPlanBuilder::from(custom_plan)
1638 .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])?
1639 .build()?;
1640
1641 assert_optimized_plan_equal!(
1642 plan,
1643 @r"
1644 Projection: l.a, l.c, r.a, Int32(0) AS d
1645 UserDefinedCrossJoin
1646 TableScan: l projection=[a, c]
1647 TableScan: r projection=[a]
1648 "
1649 )
1650 }
1651
1652 #[test]
1653 fn aggregate_no_group_by() -> Result<()> {
1654 let table_scan = test_table_scan()?;
1655
1656 let plan = LogicalPlanBuilder::from(table_scan)
1657 .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1658 .build()?;
1659
1660 assert_optimized_plan_equal!(
1661 plan,
1662 @r"
1663 Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1664 TableScan: test projection=[b]
1665 "
1666 )
1667 }
1668
1669 #[test]
1670 fn aggregate_group_by() -> Result<()> {
1671 let table_scan = test_table_scan()?;
1672
1673 let plan = LogicalPlanBuilder::from(table_scan)
1674 .aggregate(vec![col("c")], vec![max(col("b"))])?
1675 .build()?;
1676
1677 assert_optimized_plan_equal!(
1678 plan,
1679 @r"
1680 Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
1681 TableScan: test projection=[b, c]
1682 "
1683 )
1684 }
1685
1686 #[test]
1687 fn aggregate_group_by_with_table_alias() -> Result<()> {
1688 let table_scan = test_table_scan()?;
1689
1690 let plan = LogicalPlanBuilder::from(table_scan)
1691 .alias("a")?
1692 .aggregate(vec![col("c")], vec![max(col("b"))])?
1693 .build()?;
1694
1695 assert_optimized_plan_equal!(
1696 plan,
1697 @r"
1698 Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
1699 SubqueryAlias: a
1700 TableScan: test projection=[b, c]
1701 "
1702 )
1703 }
1704
1705 #[test]
1706 fn aggregate_no_group_by_with_filter() -> Result<()> {
1707 let table_scan = test_table_scan()?;
1708
1709 let plan = LogicalPlanBuilder::from(table_scan)
1710 .filter(col("c").gt(lit(1)))?
1711 .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
1712 .build()?;
1713
1714 assert_optimized_plan_equal!(
1715 plan,
1716 @r"
1717 Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
1718 Projection: test.b
1719 Filter: test.c > Int32(1)
1720 TableScan: test projection=[b, c]
1721 "
1722 )
1723 }
1724
1725 #[test]
1726 fn aggregate_with_periods() -> Result<()> {
1727 let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
1728
1729 let plan = table_scan(Some("m4"), &schema, None)?
1736 .aggregate(
1737 Vec::<Expr>::new(),
1738 vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")],
1739 )?
1740 .project([col(Column::new_unqualified("tag.one"))])?
1741 .build()?;
1742
1743 assert_optimized_plan_equal!(
1744 plan,
1745 @r"
1746 Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
1747 TableScan: m4 projection=[tag.one]
1748 "
1749 )
1750 }
1751
1752 #[test]
1753 fn redundant_project() -> Result<()> {
1754 let table_scan = test_table_scan()?;
1755
1756 let plan = LogicalPlanBuilder::from(table_scan)
1757 .project(vec![col("a"), col("b"), col("c")])?
1758 .project(vec![col("a"), col("c"), col("b")])?
1759 .build()?;
1760 assert_optimized_plan_equal!(
1761 plan,
1762 @r"
1763 Projection: test.a, test.c, test.b
1764 TableScan: test projection=[a, b, c]
1765 "
1766 )
1767 }
1768
1769 #[test]
1770 fn reorder_scan() -> Result<()> {
1771 let schema = Schema::new(test_table_scan_fields());
1772
1773 let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?;
1774 assert_optimized_plan_equal!(
1775 plan,
1776 @"TableScan: test projection=[b, a, c]"
1777 )
1778 }
1779
1780 #[test]
1781 fn reorder_scan_projection() -> Result<()> {
1782 let schema = Schema::new(test_table_scan_fields());
1783
1784 let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?
1785 .project(vec![col("a"), col("b")])?
1786 .build()?;
1787 assert_optimized_plan_equal!(
1788 plan,
1789 @r"
1790 Projection: test.a, test.b
1791 TableScan: test projection=[b, a]
1792 "
1793 )
1794 }
1795
1796 #[test]
1797 fn reorder_projection() -> Result<()> {
1798 let table_scan = test_table_scan()?;
1799
1800 let plan = LogicalPlanBuilder::from(table_scan)
1801 .project(vec![col("c"), col("b"), col("a")])?
1802 .build()?;
1803 assert_optimized_plan_equal!(
1804 plan,
1805 @r"
1806 Projection: test.c, test.b, test.a
1807 TableScan: test projection=[a, b, c]
1808 "
1809 )
1810 }
1811
1812 #[test]
1813 fn noncontinuous_redundant_projection() -> Result<()> {
1814 let table_scan = test_table_scan()?;
1815
1816 let plan = LogicalPlanBuilder::from(table_scan)
1817 .project(vec![col("c"), col("b"), col("a")])?
1818 .filter(col("c").gt(lit(1)))?
1819 .project(vec![col("c"), col("a"), col("b")])?
1820 .filter(col("b").gt(lit(1)))?
1821 .filter(col("a").gt(lit(1)))?
1822 .project(vec![col("a"), col("c"), col("b")])?
1823 .build()?;
1824 assert_optimized_plan_equal!(
1825 plan,
1826 @r"
1827 Projection: test.a, test.c, test.b
1828 Filter: test.a > Int32(1)
1829 Filter: test.b > Int32(1)
1830 Projection: test.c, test.a, test.b
1831 Filter: test.c > Int32(1)
1832 Projection: test.c, test.b, test.a
1833 TableScan: test projection=[a, b, c]
1834 "
1835 )
1836 }
1837
1838 #[test]
1839 fn join_schema_trim_full_join_column_projection() -> Result<()> {
1840 let table_scan = test_table_scan()?;
1841
1842 let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1843 let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1844
1845 let plan = LogicalPlanBuilder::from(table_scan)
1846 .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1847 .project(vec![col("a"), col("b"), col("c1")])?
1848 .build()?;
1849
1850 let optimized_plan = optimize(plan)?;
1851
1852 assert_snapshot!(
1854 optimized_plan.clone(),
1855 @r"
1856 Left Join: test.a = test2.c1
1857 TableScan: test projection=[a, b]
1858 TableScan: test2 projection=[c1]
1859 "
1860 );
1861
1862 let optimized_join = optimized_plan;
1864 assert_eq!(
1865 **optimized_join.schema(),
1866 DFSchema::new_with_metadata(
1867 vec![
1868 (
1869 Some("test".into()),
1870 Arc::new(Field::new("a", DataType::UInt32, false))
1871 ),
1872 (
1873 Some("test".into()),
1874 Arc::new(Field::new("b", DataType::UInt32, false))
1875 ),
1876 (
1877 Some("test2".into()),
1878 Arc::new(Field::new("c1", DataType::UInt32, true))
1879 ),
1880 ],
1881 HashMap::new()
1882 )?,
1883 );
1884
1885 Ok(())
1886 }
1887
1888 #[test]
1889 fn join_schema_trim_partial_join_column_projection() -> Result<()> {
1890 let table_scan = test_table_scan()?;
1893
1894 let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
1895 let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1896
1897 let plan = LogicalPlanBuilder::from(table_scan)
1898 .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
1899 .project(vec![col("a"), col("b")])?
1902 .build()?;
1903
1904 let optimized_plan = optimize(plan)?;
1905
1906 assert_snapshot!(
1908 optimized_plan.clone(),
1909 @r"
1910 Projection: test.a, test.b
1911 Left Join: test.a = test2.c1
1912 TableScan: test projection=[a, b]
1913 TableScan: test2 projection=[c1]
1914 "
1915 );
1916
1917 let optimized_join = optimized_plan.inputs()[0];
1919 assert_eq!(
1920 **optimized_join.schema(),
1921 DFSchema::new_with_metadata(
1922 vec![
1923 (
1924 Some("test".into()),
1925 Arc::new(Field::new("a", DataType::UInt32, false))
1926 ),
1927 (
1928 Some("test".into()),
1929 Arc::new(Field::new("b", DataType::UInt32, false))
1930 ),
1931 (
1932 Some("test2".into()),
1933 Arc::new(Field::new("c1", DataType::UInt32, true))
1934 ),
1935 ],
1936 HashMap::new()
1937 )?,
1938 );
1939
1940 Ok(())
1941 }
1942
1943 #[test]
1944 fn join_schema_trim_using_join() -> Result<()> {
1945 let table_scan = test_table_scan()?;
1948
1949 let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1950 let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?;
1951
1952 let plan = LogicalPlanBuilder::from(table_scan)
1953 .join_using(table2_scan, JoinType::Left, vec!["a".into()])?
1954 .project(vec![col("a"), col("b")])?
1955 .build()?;
1956
1957 let optimized_plan = optimize(plan)?;
1958
1959 assert_snapshot!(
1961 optimized_plan.clone(),
1962 @r"
1963 Projection: test.a, test.b
1964 Left Join: Using test.a = test2.a
1965 TableScan: test projection=[a, b]
1966 TableScan: test2 projection=[a]
1967 "
1968 );
1969
1970 let optimized_join = optimized_plan.inputs()[0];
1972 assert_eq!(
1973 **optimized_join.schema(),
1974 DFSchema::new_with_metadata(
1975 vec![
1976 (
1977 Some("test".into()),
1978 Arc::new(Field::new("a", DataType::UInt32, false))
1979 ),
1980 (
1981 Some("test".into()),
1982 Arc::new(Field::new("b", DataType::UInt32, false))
1983 ),
1984 (
1985 Some("test2".into()),
1986 Arc::new(Field::new("a", DataType::UInt32, true))
1987 ),
1988 ],
1989 HashMap::new()
1990 )?,
1991 );
1992
1993 Ok(())
1994 }
1995
1996 #[test]
1997 fn cast() -> Result<()> {
1998 let table_scan = test_table_scan()?;
1999
2000 let plan = LogicalPlanBuilder::from(table_scan)
2001 .project(vec![Expr::Cast(Cast::new(
2002 Box::new(col("c")),
2003 DataType::Float64,
2004 ))])?
2005 .build()?;
2006
2007 assert_optimized_plan_equal!(
2008 plan,
2009 @r"
2010 Projection: CAST(test.c AS Float64)
2011 TableScan: test projection=[c]
2012 "
2013 )
2014 }
2015
2016 #[test]
2017 fn table_scan_projected_schema() -> Result<()> {
2018 let table_scan = test_table_scan()?;
2019 let plan = LogicalPlanBuilder::from(test_table_scan()?)
2020 .project(vec![col("a"), col("b")])?
2021 .build()?;
2022
2023 assert_eq!(3, table_scan.schema().fields().len());
2024 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2025 assert_fields_eq(&plan, vec!["a", "b"]);
2026
2027 assert_optimized_plan_equal!(
2028 plan,
2029 @"TableScan: test projection=[a, b]"
2030 )
2031 }
2032
2033 #[test]
2034 fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
2035 let table_scan = test_table_scan()?;
2036 let input_schema = table_scan.schema();
2037 assert_eq!(3, input_schema.fields().len());
2038 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2039
2040 let expr = vec![col("test.a"), col("test.b")];
2044 let plan =
2045 LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?);
2046
2047 assert_fields_eq(&plan, vec!["a", "b"]);
2048
2049 assert_optimized_plan_equal!(
2050 plan,
2051 @"TableScan: test projection=[a, b]"
2052 )
2053 }
2054
2055 #[test]
2056 fn table_limit() -> Result<()> {
2057 let table_scan = test_table_scan()?;
2058 assert_eq!(3, table_scan.schema().fields().len());
2059 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2060
2061 let plan = LogicalPlanBuilder::from(table_scan)
2062 .project(vec![col("c"), col("a")])?
2063 .limit(0, Some(5))?
2064 .build()?;
2065
2066 assert_fields_eq(&plan, vec!["c", "a"]);
2067
2068 assert_optimized_plan_equal!(
2069 plan,
2070 @r"
2071 Limit: skip=0, fetch=5
2072 Projection: test.c, test.a
2073 TableScan: test projection=[a, c]
2074 "
2075 )
2076 }
2077
2078 #[test]
2079 fn table_scan_without_projection() -> Result<()> {
2080 let table_scan = test_table_scan()?;
2081 let plan = LogicalPlanBuilder::from(table_scan).build()?;
2082 assert_optimized_plan_equal!(
2084 plan,
2085 @"TableScan: test projection=[a, b, c]"
2086 )
2087 }
2088
2089 #[test]
2090 fn table_scan_with_literal_projection() -> Result<()> {
2091 let table_scan = test_table_scan()?;
2092 let plan = LogicalPlanBuilder::from(table_scan)
2093 .project(vec![lit(1_i64), lit(2_i64)])?
2094 .build()?;
2095 assert_optimized_plan_equal!(
2096 plan,
2097 @r"
2098 Projection: Int64(1), Int64(2)
2099 TableScan: test projection=[]
2100 "
2101 )
2102 }
2103
2104 #[test]
2106 fn table_unused_column() -> Result<()> {
2107 let table_scan = test_table_scan()?;
2108 assert_eq!(3, table_scan.schema().fields().len());
2109 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2110
2111 let plan = LogicalPlanBuilder::from(table_scan)
2113 .project(vec![col("c"), col("a"), col("b")])?
2114 .filter(col("c").gt(lit(1)))?
2115 .aggregate(vec![col("c")], vec![max(col("a"))])?
2116 .build()?;
2117
2118 assert_fields_eq(&plan, vec!["c", "max(test.a)"]);
2119
2120 let plan = optimize(plan).expect("failed to optimize plan");
2121 assert_optimized_plan_equal!(
2122 plan,
2123 @r"
2124 Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
2125 Filter: test.c > Int32(1)
2126 Projection: test.c, test.a
2127 TableScan: test projection=[a, c]
2128 "
2129 )
2130 }
2131
2132 #[test]
2134 fn table_unused_projection() -> Result<()> {
2135 let table_scan = test_table_scan()?;
2136 assert_eq!(3, table_scan.schema().fields().len());
2137 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2138
2139 let plan = LogicalPlanBuilder::from(table_scan)
2141 .project(vec![col("b")])?
2142 .project(vec![lit(1).alias("a")])?
2143 .build()?;
2144
2145 assert_fields_eq(&plan, vec!["a"]);
2146
2147 assert_optimized_plan_equal!(
2148 plan,
2149 @r"
2150 Projection: Int32(1) AS a
2151 TableScan: test projection=[]
2152 "
2153 )
2154 }
2155
2156 #[test]
2157 fn table_full_filter_pushdown() -> Result<()> {
2158 let schema = Schema::new(test_table_scan_fields());
2159
2160 let table_scan = table_scan_with_filters(
2161 Some("test"),
2162 &schema,
2163 None,
2164 vec![col("b").eq(lit(1))],
2165 )?
2166 .build()?;
2167 assert_eq!(3, table_scan.schema().fields().len());
2168 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2169
2170 let plan = LogicalPlanBuilder::from(table_scan)
2172 .project(vec![col("b")])?
2173 .project(vec![lit(1).alias("a")])?
2174 .build()?;
2175
2176 assert_fields_eq(&plan, vec!["a"]);
2177
2178 assert_optimized_plan_equal!(
2179 plan,
2180 @r"
2181 Projection: Int32(1) AS a
2182 TableScan: test projection=[], full_filters=[b = Int32(1)]
2183 "
2184 )
2185 }
2186
2187 #[test]
2189 fn test_double_optimization() -> Result<()> {
2190 let table_scan = test_table_scan()?;
2191
2192 let plan = LogicalPlanBuilder::from(table_scan)
2193 .project(vec![col("b")])?
2194 .project(vec![lit(1).alias("a")])?
2195 .build()?;
2196
2197 let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
2198 let optimized_plan2 =
2199 optimize(optimized_plan1.clone()).expect("failed to optimize plan");
2200
2201 let formatted_plan1 = format!("{optimized_plan1:?}");
2202 let formatted_plan2 = format!("{optimized_plan2:?}");
2203 assert_eq!(formatted_plan1, formatted_plan2);
2204 Ok(())
2205 }
2206
2207 #[test]
2209 fn table_unused_aggregate() -> Result<()> {
2210 let table_scan = test_table_scan()?;
2211 assert_eq!(3, table_scan.schema().fields().len());
2212 assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
2213
2214 let plan = LogicalPlanBuilder::from(table_scan)
2216 .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
2217 .filter(col("c").gt(lit(1)))?
2218 .project(vec![col("c"), col("a"), col("max(test.b)")])?
2219 .build()?;
2220
2221 assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);
2222
2223 assert_optimized_plan_equal!(
2224 plan,
2225 @r"
2226 Projection: test.c, test.a, max(test.b)
2227 Filter: test.c > Int32(1)
2228 Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
2229 TableScan: test projection=[a, b, c]
2230 "
2231 )
2232 }
2233
2234 #[test]
2235 fn aggregate_filter_pushdown() -> Result<()> {
2236 let table_scan = test_table_scan()?;
2237 let aggr_with_filter = count_udaf()
2238 .call(vec![col("b")])
2239 .filter(col("c").gt(lit(42)))
2240 .build()?;
2241 let plan = LogicalPlanBuilder::from(table_scan)
2242 .aggregate(
2243 vec![col("a")],
2244 vec![count(col("b")), aggr_with_filter.alias("count2")],
2245 )?
2246 .build()?;
2247
2248 assert_optimized_plan_equal!(
2249 plan,
2250 @r"
2251 Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
2252 TableScan: test projection=[a, b, c]
2253 "
2254 )
2255 }
2256
2257 #[test]
2258 fn pushdown_through_distinct() -> Result<()> {
2259 let table_scan = test_table_scan()?;
2260
2261 let plan = LogicalPlanBuilder::from(table_scan)
2262 .project(vec![col("a"), col("b")])?
2263 .distinct()?
2264 .project(vec![col("a")])?
2265 .build()?;
2266
2267 assert_optimized_plan_equal!(
2268 plan,
2269 @r"
2270 Projection: test.a
2271 Distinct:
2272 TableScan: test projection=[a, b]
2273 "
2274 )
2275 }
2276
2277 #[test]
2278 fn test_window() -> Result<()> {
2279 let table_scan = test_table_scan()?;
2280
2281 let max1 = Expr::from(expr::WindowFunction::new(
2282 WindowFunctionDefinition::AggregateUDF(max_udaf()),
2283 vec![col("test.a")],
2284 ))
2285 .partition_by(vec![col("test.b")])
2286 .build()
2287 .unwrap();
2288
2289 let max2 = Expr::from(expr::WindowFunction::new(
2290 WindowFunctionDefinition::AggregateUDF(max_udaf()),
2291 vec![col("test.b")],
2292 ));
2293 let col1 = col(max1.schema_name().to_string());
2294 let col2 = col(max2.schema_name().to_string());
2295
2296 let plan = LogicalPlanBuilder::from(table_scan)
2297 .window(vec![max1])?
2298 .window(vec![max2])?
2299 .project(vec![col1, col2])?
2300 .build()?;
2301
2302 assert_optimized_plan_equal!(
2303 plan,
2304 @r"
2305 Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2306 WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2307 Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2308 WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
2309 TableScan: test projection=[a, b]
2310 "
2311 )
2312 }
2313
2314 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2315
2316 fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2317 let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2318 let optimized_plan =
2319 optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2320 Ok(optimized_plan)
2321 }
2322}