1mod required_indices;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24use std::collections::HashSet;
25use std::sync::Arc;
26
27use datafusion_common::{
28 get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column,
29 DFSchema, HashMap, JoinType, Result,
30};
31use datafusion_expr::expr::Alias;
32use datafusion_expr::{
33 logical_plan::LogicalPlan, Aggregate, Distinct, EmptyRelation, Expr, Projection,
34 TableScan, Unnest, Window,
35};
36
37use crate::optimize_projections::required_indices::RequiredIndices;
38use crate::utils::NamePreserver;
39use datafusion_common::tree_node::{
40 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
41};
42
/// Optimizer rule that prunes columns and expressions which are not needed by
/// the parent of each plan node.
///
/// The actual tree walk lives in `optimize_projections`; this type only
/// exposes it through the `OptimizerRule` trait.
#[derive(Default, Debug)]
pub struct OptimizeProjections {}

impl OptimizeProjections {
    /// Creates a new instance of the rule (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
85
86impl OptimizerRule for OptimizeProjections {
87 fn name(&self) -> &str {
88 "optimize_projections"
89 }
90
91 fn apply_order(&self) -> Option<ApplyOrder> {
92 None
93 }
94
95 fn supports_rewrite(&self) -> bool {
96 true
97 }
98
99 fn rewrite(
100 &self,
101 plan: LogicalPlan,
102 config: &dyn OptimizerConfig,
103 ) -> Result<Transformed<LogicalPlan>> {
104 let indices = RequiredIndices::new_for_all_exprs(&plan);
106 optimize_projections(plan, config, indices)
107 }
108}
109
110#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
130fn optimize_projections(
131 plan: LogicalPlan,
132 config: &dyn OptimizerConfig,
133 indices: RequiredIndices,
134) -> Result<Transformed<LogicalPlan>> {
135 match plan {
138 LogicalPlan::Projection(proj) => {
139 return merge_consecutive_projections(proj)?.transform_data(|proj| {
140 rewrite_projection_given_requirements(proj, config, &indices)
141 })
142 }
143 LogicalPlan::Aggregate(aggregate) => {
144 let n_group_exprs = aggregate.group_expr_len()?;
146 let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);
149
150 let group_by_expr_existing = aggregate
152 .group_expr
153 .iter()
154 .map(|group_by_expr| group_by_expr.schema_name().to_string())
155 .collect::<Vec<_>>();
156
157 let new_group_bys = if let Some(simplest_groupby_indices) =
158 get_required_group_by_exprs_indices(
159 aggregate.input.schema(),
160 &group_by_expr_existing,
161 ) {
162 group_by_reqs
166 .append(&simplest_groupby_indices)
167 .get_at_indices(&aggregate.group_expr)
168 } else {
169 aggregate.group_expr
170 };
171
172 let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr);
175
176 if new_group_bys.is_empty() && new_aggr_expr.is_empty() {
177 return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
179 EmptyRelation {
180 produce_one_row: true,
181 schema: Arc::new(DFSchema::empty()),
182 },
183 )));
184 }
185
186 let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
187 let schema = aggregate.input.schema();
188 let necessary_indices =
189 RequiredIndices::new().with_exprs(schema, all_exprs_iter);
190 let necessary_exprs = necessary_indices.get_required_exprs(schema);
191
192 return optimize_projections(
193 Arc::unwrap_or_clone(aggregate.input),
194 config,
195 necessary_indices,
196 )?
197 .transform_data(|aggregate_input| {
198 add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)
203 })?
204 .map_data(|aggregate_input| {
205 Aggregate::try_new(
208 Arc::new(aggregate_input),
209 new_group_bys,
210 new_aggr_expr,
211 )
212 .map(LogicalPlan::Aggregate)
213 });
214 }
215 LogicalPlan::Window(window) => {
216 let input_schema = Arc::clone(window.input.schema());
217 let n_input_fields = input_schema.fields().len();
219 let (child_reqs, window_reqs) = indices.split_off(n_input_fields);
222
223 let new_window_expr = window_reqs.get_at_indices(&window.window_expr);
226
227 let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);
230
231 return optimize_projections(
232 Arc::unwrap_or_clone(window.input),
233 config,
234 required_indices.clone(),
235 )?
236 .transform_data(|window_child| {
237 if new_window_expr.is_empty() {
238 Ok(Transformed::no(window_child))
240 } else {
241 let required_exprs =
245 required_indices.get_required_exprs(&input_schema);
246 let window_child =
247 add_projection_on_top_if_helpful(window_child, required_exprs)?
248 .data;
249 Window::try_new(new_window_expr, Arc::new(window_child))
250 .map(LogicalPlan::Window)
251 .map(Transformed::yes)
252 }
253 });
254 }
255 LogicalPlan::TableScan(table_scan) => {
256 let TableScan {
257 table_name,
258 source,
259 projection,
260 filters,
261 fetch,
262 projected_schema: _,
263 } = table_scan;
264
265 let projection = match &projection {
268 Some(projection) => indices.into_mapped_indices(|idx| projection[idx]),
269 None => indices.into_inner(),
270 };
271 return TableScan::try_new(
272 table_name,
273 source,
274 Some(projection),
275 filters,
276 fetch,
277 )
278 .map(LogicalPlan::TableScan)
279 .map(Transformed::yes);
280 }
281 _ => {}
283 };
284
285 let mut child_required_indices: Vec<RequiredIndices> = match &plan {
288 LogicalPlan::Sort(_)
289 | LogicalPlan::Filter(_)
290 | LogicalPlan::Repartition(_)
291 | LogicalPlan::Union(_)
292 | LogicalPlan::SubqueryAlias(_)
293 | LogicalPlan::Distinct(Distinct::On(_)) => {
294 plan.inputs()
299 .into_iter()
300 .map(|input| {
301 indices
302 .clone()
303 .with_projection_beneficial()
304 .with_plan_exprs(&plan, input.schema())
305 })
306 .collect::<Result<_>>()?
307 }
308 LogicalPlan::Limit(_) => {
309 plan.inputs()
314 .into_iter()
315 .map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
316 .collect::<Result<_>>()?
317 }
318 LogicalPlan::Copy(_)
319 | LogicalPlan::Ddl(_)
320 | LogicalPlan::Dml(_)
321 | LogicalPlan::Explain(_)
322 | LogicalPlan::Analyze(_)
323 | LogicalPlan::Subquery(_)
324 | LogicalPlan::Statement(_)
325 | LogicalPlan::Distinct(Distinct::All(_)) => {
326 plan.inputs()
332 .into_iter()
333 .map(RequiredIndices::new_for_all_exprs)
334 .collect()
335 }
336 LogicalPlan::Extension(extension) => {
337 let Some(necessary_children_indices) =
338 extension.node.necessary_children_exprs(indices.indices())
339 else {
340 return Ok(Transformed::no(plan));
342 };
343 let children = extension.node.inputs();
344 if children.len() != necessary_children_indices.len() {
345 return internal_err!("Inconsistent length between children and necessary children indices. \
346 Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \
347 consistent with actual children length for the node.");
348 }
349 children
350 .into_iter()
351 .zip(necessary_children_indices)
352 .map(|(child, necessary_indices)| {
353 RequiredIndices::new_from_indices(necessary_indices)
354 .with_plan_exprs(&plan, child.schema())
355 })
356 .collect::<Result<Vec<_>>>()?
357 }
358 LogicalPlan::EmptyRelation(_)
359 | LogicalPlan::Values(_)
360 | LogicalPlan::DescribeTable(_) => {
361 return Ok(Transformed::no(plan));
363 }
364 LogicalPlan::RecursiveQuery(recursive) => {
365 if plan_contains_other_subqueries(
369 recursive.static_term.as_ref(),
370 &recursive.name,
371 ) || plan_contains_other_subqueries(
372 recursive.recursive_term.as_ref(),
373 &recursive.name,
374 ) {
375 return Ok(Transformed::no(plan));
376 }
377
378 plan.inputs()
379 .into_iter()
380 .map(|input| {
381 indices
382 .clone()
383 .with_projection_beneficial()
384 .with_plan_exprs(&plan, input.schema())
385 })
386 .collect::<Result<Vec<_>>>()?
387 }
388 LogicalPlan::Join(join) => {
389 let left_len = join.left.schema().fields().len();
390 let (left_req_indices, right_req_indices) =
391 split_join_requirements(left_len, indices, &join.join_type);
392 let left_indices =
393 left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
394 let right_indices =
395 right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
396 vec![
399 left_indices.with_projection_beneficial(),
400 right_indices.with_projection_beneficial(),
401 ]
402 }
403 LogicalPlan::Projection(_)
405 | LogicalPlan::Aggregate(_)
406 | LogicalPlan::Window(_)
407 | LogicalPlan::TableScan(_) => {
408 return internal_err!(
409 "OptimizeProjection: should have handled in the match statement above"
410 );
411 }
412 LogicalPlan::Unnest(Unnest {
413 input,
414 dependency_indices,
415 ..
416 }) => {
417 let required_indices =
419 RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;
420
421 let mut additional_necessary_child_indices = Vec::new();
423 indices.indices().iter().for_each(|idx| {
424 if let Some(index) = dependency_indices.get(*idx) {
425 additional_necessary_child_indices.push(*index);
426 }
427 });
428 vec![required_indices.append(&additional_necessary_child_indices)]
429 }
430 };
431
432 child_required_indices.reverse();
435 if child_required_indices.len() != plan.inputs().len() {
436 return internal_err!(
437 "OptimizeProjection: child_required_indices length mismatch with plan inputs"
438 );
439 }
440
441 let transformed_plan = plan.map_children(|child| {
443 let required_indices = child_required_indices.pop().ok_or_else(|| {
444 internal_datafusion_err!(
445 "Unexpected number of required_indices in OptimizeProjections rule"
446 )
447 })?;
448
449 let projection_beneficial = required_indices.projection_beneficial();
450 let project_exprs = required_indices.get_required_exprs(child.schema());
451
452 optimize_projections(child, config, required_indices)?.transform_data(
453 |new_input| {
454 if projection_beneficial {
455 add_projection_on_top_if_helpful(new_input, project_exprs)
456 } else {
457 Ok(Transformed::no(new_input))
458 }
459 },
460 )
461 })?;
462
463 if transformed_plan.transformed {
465 transformed_plan.map_data(|plan| plan.recompute_schema())
466 } else {
467 Ok(transformed_plan)
468 }
469}
470
471fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
504 let Projection {
505 expr,
506 input,
507 schema,
508 ..
509 } = proj;
510 let LogicalPlan::Projection(prev_projection) = input.as_ref() else {
511 return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
512 };
513
514 if prev_projection.expr == expr {
517 return Projection::try_new_with_schema(
518 expr,
519 Arc::clone(&prev_projection.input),
520 schema,
521 )
522 .map(Transformed::yes);
523 }
524
525 let mut column_referral_map = HashMap::<&Column, usize>::new();
527 expr.iter()
528 .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
529
530 if column_referral_map.into_iter().any(|(col, usage)| {
534 usage > 1
535 && !is_expr_trivial(
536 &prev_projection.expr
537 [prev_projection.schema.index_of_column(col).unwrap()],
538 )
539 }) {
540 return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no);
542 }
543
544 let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else {
545 unreachable!();
547 };
548
549 let name_preserver = NamePreserver::new_for_projection();
552 let mut original_names = vec![];
553 let new_exprs = expr.map_elements(|expr| {
554 original_names.push(name_preserver.save(&expr));
555
556 match expr {
558 Expr::Alias(Alias {
559 expr,
560 relation,
561 name,
562 metadata,
563 }) => rewrite_expr(*expr, &prev_projection).map(|result| {
564 result.update_data(|expr| {
565 Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata))
566 })
567 }),
568 e => rewrite_expr(e, &prev_projection),
569 }
570 })?;
571
572 if new_exprs.transformed {
575 let new_exprs = new_exprs
577 .data
578 .into_iter()
579 .zip(original_names)
580 .map(|(expr, original_name)| original_name.restore(expr))
581 .collect::<Vec<_>>();
582 Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
583 } else {
584 let input = Arc::new(LogicalPlan::Projection(prev_projection));
586 Projection::try_new_with_schema(new_exprs.data, input, schema)
587 .map(Transformed::no)
588 }
589}
590
591fn is_expr_trivial(expr: &Expr) -> bool {
593 matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
594}
595
596fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
642 expr.transform_up(|expr| {
643 match expr {
644 Expr::Alias(alias) => {
646 match alias
647 .metadata
648 .as_ref()
649 .map(|h| h.is_empty())
650 .unwrap_or(true)
651 {
652 true => Ok(Transformed::yes(*alias.expr)),
653 false => Ok(Transformed::no(Expr::Alias(alias))),
654 }
655 }
656 Expr::Column(col) => {
657 let idx = input.schema.index_of_column(&col)?;
659 let input_expr = input.expr[idx].clone().unalias_nested().data;
667 Ok(Transformed::yes(input_expr))
668 }
669 _ => Ok(Transformed::no(expr)),
671 }
672 })
673}
674
675fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
684 expr.apply(|expr| {
686 match expr {
687 Expr::OuterReferenceColumn(_, col) => {
688 columns.insert(col);
689 }
690 Expr::ScalarSubquery(subquery) => {
691 outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
692 }
693 Expr::Exists(exists) => {
694 outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns);
695 }
696 Expr::InSubquery(insubquery) => {
697 outer_columns_helper_multi(
698 &insubquery.subquery.outer_ref_columns,
699 columns,
700 );
701 }
702 _ => {}
703 };
704 Ok(TreeNodeRecursion::Continue)
705 })
706 .unwrap();
708}
709
710fn outer_columns_helper_multi<'a, 'b>(
719 exprs: impl IntoIterator<Item = &'a Expr>,
720 columns: &'b mut HashSet<&'a Column>,
721) {
722 exprs.into_iter().for_each(|e| outer_columns(e, columns));
723}
724
725fn split_join_requirements(
755 left_len: usize,
756 indices: RequiredIndices,
757 join_type: &JoinType,
758) -> (RequiredIndices, RequiredIndices) {
759 match join_type {
760 JoinType::Inner
762 | JoinType::Left
763 | JoinType::Right
764 | JoinType::Full
765 | JoinType::LeftMark
766 | JoinType::RightMark => {
767 indices.split_off(left_len)
770 }
771 JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
773 JoinType::RightSemi | JoinType::RightAnti => (RequiredIndices::new(), indices),
776 }
777}
778
779fn add_projection_on_top_if_helpful(
797 plan: LogicalPlan,
798 project_exprs: Vec<Expr>,
799) -> Result<Transformed<LogicalPlan>> {
800 if project_exprs.len() >= plan.schema().fields().len() {
802 Ok(Transformed::no(plan))
803 } else {
804 Projection::try_new(project_exprs, Arc::new(plan))
805 .map(LogicalPlan::Projection)
806 .map(Transformed::yes)
807 }
808}
809
810fn rewrite_projection_given_requirements(
828 proj: Projection,
829 config: &dyn OptimizerConfig,
830 indices: &RequiredIndices,
831) -> Result<Transformed<LogicalPlan>> {
832 let Projection { expr, input, .. } = proj;
833
834 let exprs_used = indices.get_at_indices(&expr);
835
836 let required_indices =
837 RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
838
839 optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?
842 .transform_data(|input| {
843 if is_projection_unnecessary(&input, &exprs_used)? {
844 Ok(Transformed::yes(input))
845 } else {
846 Projection::try_new(exprs_used, Arc::new(input))
847 .map(LogicalPlan::Projection)
848 .map(Transformed::yes)
849 }
850 })
851}
852
853pub fn is_projection_unnecessary(
857 input: &LogicalPlan,
858 proj_exprs: &[Expr],
859) -> Result<bool> {
860 if proj_exprs.len() != input.schema().fields().len() {
862 return Ok(false);
863 }
864 Ok(input.schema().iter().zip(proj_exprs.iter()).all(
865 |((field_relation, field_name), expr)| {
866 if let Expr::Column(col) = expr {
868 col.relation.as_ref() == field_relation && col.name.eq(field_name.name())
869 } else {
870 false
871 }
872 },
873 ))
874}
875
876fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
882 if let LogicalPlan::SubqueryAlias(alias) = plan {
883 if alias.alias.table() != cte_name
884 && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
885 {
886 return true;
887 }
888 }
889
890 let mut found = false;
891 plan.apply_expressions(|expr| {
892 if expr_contains_subquery(expr) {
893 found = true;
894 Ok(TreeNodeRecursion::Stop)
895 } else {
896 Ok(TreeNodeRecursion::Continue)
897 }
898 })
899 .expect("expression traversal never fails");
900 if found {
901 return true;
902 }
903
904 plan.inputs()
905 .into_iter()
906 .any(|child| plan_contains_other_subqueries(child, cte_name))
907}
908
909fn expr_contains_subquery(expr: &Expr) -> bool {
910 expr.exists(|e| match e {
911 Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
912 _ => Ok(false),
913 })
914 .unwrap()
916}
917
918fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool {
919 match plan {
920 LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name,
921 LogicalPlan::SubqueryAlias(alias) => {
922 subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
923 }
924 _ => {
925 let inputs = plan.inputs();
926 if inputs.len() == 1 {
927 subquery_alias_targets_recursive_cte(inputs[0], cte_name)
928 } else {
929 false
930 }
931 }
932 }
933}
934
935#[cfg(test)]
936mod tests {
937 use std::cmp::Ordering;
938 use std::collections::HashMap;
939 use std::fmt::Formatter;
940 use std::ops::Add;
941 use std::sync::Arc;
942 use std::vec;
943
944 use crate::optimize_projections::OptimizeProjections;
945 use crate::optimizer::Optimizer;
946 use crate::test::{
947 assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields,
948 test_table_scan_with_name,
949 };
950 use crate::{OptimizerContext, OptimizerRule};
951 use arrow::datatypes::{DataType, Field, Schema};
952 use datafusion_common::{
953 Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
954 };
955 use datafusion_expr::ExprFunctionExt;
956 use datafusion_expr::{
957 binary_expr, build_join_schema,
958 builder::table_scan_with_filters,
959 col,
960 expr::{self, Cast},
961 lit,
962 logical_plan::{builder::LogicalPlanBuilder, table_scan},
963 not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator,
964 Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition,
965 };
966 use insta::assert_snapshot;
967
968 use crate::assert_optimized_plan_eq_snapshot;
969 use datafusion_functions_aggregate::count::count_udaf;
970 use datafusion_functions_aggregate::expr_fn::{count, max, min};
971 use datafusion_functions_aggregate::min_max::max_udaf;
972
973 macro_rules! assert_optimized_plan_equal {
974 (
975 $plan:expr,
976 @ $expected:literal $(,)?
977 ) => {{
978 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
979 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(OptimizeProjections::new())];
980 assert_optimized_plan_eq_snapshot!(
981 optimizer_ctx,
982 rules,
983 $plan,
984 @ $expected,
985 )
986 }};
987 }
988
989 #[derive(Debug, Hash, PartialEq, Eq)]
990 struct NoOpUserDefined {
991 exprs: Vec<Expr>,
992 schema: DFSchemaRef,
993 input: Arc<LogicalPlan>,
994 }
995
996 impl NoOpUserDefined {
997 fn new(schema: DFSchemaRef, input: Arc<LogicalPlan>) -> Self {
998 Self {
999 exprs: vec![],
1000 schema,
1001 input,
1002 }
1003 }
1004
1005 fn with_exprs(mut self, exprs: Vec<Expr>) -> Self {
1006 self.exprs = exprs;
1007 self
1008 }
1009 }
1010
1011 impl PartialOrd for NoOpUserDefined {
1013 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1014 match self.exprs.partial_cmp(&other.exprs) {
1015 Some(Ordering::Equal) => self.input.partial_cmp(&other.input),
1016 cmp => cmp,
1017 }
1018 .filter(|cmp| *cmp != Ordering::Equal || self == other)
1020 }
1021 }
1022
1023 impl UserDefinedLogicalNodeCore for NoOpUserDefined {
1024 fn name(&self) -> &str {
1025 "NoOpUserDefined"
1026 }
1027
1028 fn inputs(&self) -> Vec<&LogicalPlan> {
1029 vec![&self.input]
1030 }
1031
1032 fn schema(&self) -> &DFSchemaRef {
1033 &self.schema
1034 }
1035
1036 fn expressions(&self) -> Vec<Expr> {
1037 self.exprs.clone()
1038 }
1039
1040 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1041 write!(f, "NoOpUserDefined")
1042 }
1043
1044 fn with_exprs_and_inputs(
1045 &self,
1046 exprs: Vec<Expr>,
1047 mut inputs: Vec<LogicalPlan>,
1048 ) -> Result<Self> {
1049 Ok(Self {
1050 exprs,
1051 input: Arc::new(inputs.swap_remove(0)),
1052 schema: Arc::clone(&self.schema),
1053 })
1054 }
1055
1056 fn necessary_children_exprs(
1057 &self,
1058 output_columns: &[usize],
1059 ) -> Option<Vec<Vec<usize>>> {
1060 Some(vec![output_columns.to_vec()])
1062 }
1063
1064 fn supports_limit_pushdown(&self) -> bool {
1065 false }
1067 }
1068
1069 #[derive(Debug, Hash, PartialEq, Eq)]
1070 struct UserDefinedCrossJoin {
1071 exprs: Vec<Expr>,
1072 schema: DFSchemaRef,
1073 left_child: Arc<LogicalPlan>,
1074 right_child: Arc<LogicalPlan>,
1075 }
1076
1077 impl UserDefinedCrossJoin {
1078 fn new(left_child: Arc<LogicalPlan>, right_child: Arc<LogicalPlan>) -> Self {
1079 let left_schema = left_child.schema();
1080 let right_schema = right_child.schema();
1081 let schema = Arc::new(
1082 build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(),
1083 );
1084 Self {
1085 exprs: vec![],
1086 schema,
1087 left_child,
1088 right_child,
1089 }
1090 }
1091 }
1092
1093 impl PartialOrd for UserDefinedCrossJoin {
1095 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1096 match self.exprs.partial_cmp(&other.exprs) {
1097 Some(Ordering::Equal) => {
1098 match self.left_child.partial_cmp(&other.left_child) {
1099 Some(Ordering::Equal) => {
1100 self.right_child.partial_cmp(&other.right_child)
1101 }
1102 cmp => cmp,
1103 }
1104 }
1105 cmp => cmp,
1106 }
1107 .filter(|cmp| *cmp != Ordering::Equal || self == other)
1109 }
1110 }
1111
1112 impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin {
1113 fn name(&self) -> &str {
1114 "UserDefinedCrossJoin"
1115 }
1116
1117 fn inputs(&self) -> Vec<&LogicalPlan> {
1118 vec![&self.left_child, &self.right_child]
1119 }
1120
1121 fn schema(&self) -> &DFSchemaRef {
1122 &self.schema
1123 }
1124
1125 fn expressions(&self) -> Vec<Expr> {
1126 self.exprs.clone()
1127 }
1128
1129 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1130 write!(f, "UserDefinedCrossJoin")
1131 }
1132
1133 fn with_exprs_and_inputs(
1134 &self,
1135 exprs: Vec<Expr>,
1136 mut inputs: Vec<LogicalPlan>,
1137 ) -> Result<Self> {
1138 assert_eq!(inputs.len(), 2);
1139 Ok(Self {
1140 exprs,
1141 left_child: Arc::new(inputs.remove(0)),
1142 right_child: Arc::new(inputs.remove(0)),
1143 schema: Arc::clone(&self.schema),
1144 })
1145 }
1146
1147 fn necessary_children_exprs(
1148 &self,
1149 output_columns: &[usize],
1150 ) -> Option<Vec<Vec<usize>>> {
1151 let left_child_len = self.left_child.schema().fields().len();
1152 let mut left_reqs = vec![];
1153 let mut right_reqs = vec![];
1154 for &out_idx in output_columns {
1155 if out_idx < left_child_len {
1156 left_reqs.push(out_idx);
1157 } else {
1158 right_reqs.push(out_idx - left_child_len)
1161 }
1162 }
1163 Some(vec![left_reqs, right_reqs])
1164 }
1165
1166 fn supports_limit_pushdown(&self) -> bool {
1167 false }
1169 }
1170
1171 #[test]
1172 fn merge_two_projection() -> Result<()> {
1173 let table_scan = test_table_scan()?;
1174 let plan = LogicalPlanBuilder::from(table_scan)
1175 .project(vec![col("a")])?
1176 .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1177 .build()?;
1178
1179 assert_optimized_plan_equal!(
1180 plan,
1181 @r"
1182 Projection: Int32(1) + test.a
1183 TableScan: test projection=[a]
1184 "
1185 )
1186 }
1187
1188 #[test]
1189 fn merge_three_projection() -> Result<()> {
1190 let table_scan = test_table_scan()?;
1191 let plan = LogicalPlanBuilder::from(table_scan)
1192 .project(vec![col("a"), col("b")])?
1193 .project(vec![col("a")])?
1194 .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])?
1195 .build()?;
1196
1197 assert_optimized_plan_equal!(
1198 plan,
1199 @r"
1200 Projection: Int32(1) + test.a
1201 TableScan: test projection=[a]
1202 "
1203 )
1204 }
1205
1206 #[test]
1207 fn merge_alias() -> Result<()> {
1208 let table_scan = test_table_scan()?;
1209 let plan = LogicalPlanBuilder::from(table_scan)
1210 .project(vec![col("a")])?
1211 .project(vec![col("a").alias("alias")])?
1212 .build()?;
1213
1214 assert_optimized_plan_equal!(
1215 plan,
1216 @r"
1217 Projection: test.a AS alias
1218 TableScan: test projection=[a]
1219 "
1220 )
1221 }
1222
1223 #[test]
1224 fn merge_nested_alias() -> Result<()> {
1225 let table_scan = test_table_scan()?;
1226 let plan = LogicalPlanBuilder::from(table_scan)
1227 .project(vec![col("a").alias("alias1").alias("alias2")])?
1228 .project(vec![col("alias2").alias("alias")])?
1229 .build()?;
1230
1231 assert_optimized_plan_equal!(
1232 plan,
1233 @r"
1234 Projection: test.a AS alias
1235 TableScan: test projection=[a]
1236 "
1237 )
1238 }
1239
1240 #[test]
1241 fn test_nested_count() -> Result<()> {
1242 let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]);
1243
1244 let groups: Vec<Expr> = vec![];
1245
1246 let plan = table_scan(TableReference::none(), &schema, None)
1247 .unwrap()
1248 .aggregate(groups.clone(), vec![count(lit(1))])
1249 .unwrap()
1250 .aggregate(groups, vec![count(lit(1))])
1251 .unwrap()
1252 .build()
1253 .unwrap();
1254
1255 assert_optimized_plan_equal!(
1256 plan,
1257 @r"
1258 Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
1259 EmptyRelation: rows=1
1260 "
1261 )
1262 }
1263
1264 #[test]
1265 fn test_neg_push_down() -> Result<()> {
1266 let table_scan = test_table_scan()?;
1267 let plan = LogicalPlanBuilder::from(table_scan)
1268 .project(vec![-col("a")])?
1269 .build()?;
1270
1271 assert_optimized_plan_equal!(
1272 plan,
1273 @r"
1274 Projection: (- test.a)
1275 TableScan: test projection=[a]
1276 "
1277 )
1278 }
1279
1280 #[test]
1281 fn test_is_null() -> Result<()> {
1282 let table_scan = test_table_scan()?;
1283 let plan = LogicalPlanBuilder::from(table_scan)
1284 .project(vec![col("a").is_null()])?
1285 .build()?;
1286
1287 assert_optimized_plan_equal!(
1288 plan,
1289 @r"
1290 Projection: test.a IS NULL
1291 TableScan: test projection=[a]
1292 "
1293 )
1294 }
1295
1296 #[test]
1297 fn test_is_not_null() -> Result<()> {
1298 let table_scan = test_table_scan()?;
1299 let plan = LogicalPlanBuilder::from(table_scan)
1300 .project(vec![col("a").is_not_null()])?
1301 .build()?;
1302
1303 assert_optimized_plan_equal!(
1304 plan,
1305 @r"
1306 Projection: test.a IS NOT NULL
1307 TableScan: test projection=[a]
1308 "
1309 )
1310 }
1311
1312 #[test]
1313 fn test_is_true() -> Result<()> {
1314 let table_scan = test_table_scan()?;
1315 let plan = LogicalPlanBuilder::from(table_scan)
1316 .project(vec![col("a").is_true()])?
1317 .build()?;
1318
1319 assert_optimized_plan_equal!(
1320 plan,
1321 @r"
1322 Projection: test.a IS TRUE
1323 TableScan: test projection=[a]
1324 "
1325 )
1326 }
1327
1328 #[test]
1329 fn test_is_not_true() -> Result<()> {
1330 let table_scan = test_table_scan()?;
1331 let plan = LogicalPlanBuilder::from(table_scan)
1332 .project(vec![col("a").is_not_true()])?
1333 .build()?;
1334
1335 assert_optimized_plan_equal!(
1336 plan,
1337 @r"
1338 Projection: test.a IS NOT TRUE
1339 TableScan: test projection=[a]
1340 "
1341 )
1342 }
1343
1344 #[test]
1345 fn test_is_false() -> Result<()> {
1346 let table_scan = test_table_scan()?;
1347 let plan = LogicalPlanBuilder::from(table_scan)
1348 .project(vec![col("a").is_false()])?
1349 .build()?;
1350
1351 assert_optimized_plan_equal!(
1352 plan,
1353 @r"
1354 Projection: test.a IS FALSE
1355 TableScan: test projection=[a]
1356 "
1357 )
1358 }
1359
1360 #[test]
1361 fn test_is_not_false() -> Result<()> {
1362 let table_scan = test_table_scan()?;
1363 let plan = LogicalPlanBuilder::from(table_scan)
1364 .project(vec![col("a").is_not_false()])?
1365 .build()?;
1366
1367 assert_optimized_plan_equal!(
1368 plan,
1369 @r"
1370 Projection: test.a IS NOT FALSE
1371 TableScan: test projection=[a]
1372 "
1373 )
1374 }
1375
1376 #[test]
1377 fn test_is_unknown() -> Result<()> {
1378 let table_scan = test_table_scan()?;
1379 let plan = LogicalPlanBuilder::from(table_scan)
1380 .project(vec![col("a").is_unknown()])?
1381 .build()?;
1382
1383 assert_optimized_plan_equal!(
1384 plan,
1385 @r"
1386 Projection: test.a IS UNKNOWN
1387 TableScan: test projection=[a]
1388 "
1389 )
1390 }
1391
1392 #[test]
1393 fn test_is_not_unknown() -> Result<()> {
1394 let table_scan = test_table_scan()?;
1395 let plan = LogicalPlanBuilder::from(table_scan)
1396 .project(vec![col("a").is_not_unknown()])?
1397 .build()?;
1398
1399 assert_optimized_plan_equal!(
1400 plan,
1401 @r"
1402 Projection: test.a IS NOT UNKNOWN
1403 TableScan: test projection=[a]
1404 "
1405 )
1406 }
1407
1408 #[test]
1409 fn test_not() -> Result<()> {
1410 let table_scan = test_table_scan()?;
1411 let plan = LogicalPlanBuilder::from(table_scan)
1412 .project(vec![not(col("a"))])?
1413 .build()?;
1414
1415 assert_optimized_plan_equal!(
1416 plan,
1417 @r"
1418 Projection: NOT test.a
1419 TableScan: test projection=[a]
1420 "
1421 )
1422 }
1423
1424 #[test]
1425 fn test_try_cast() -> Result<()> {
1426 let table_scan = test_table_scan()?;
1427 let plan = LogicalPlanBuilder::from(table_scan)
1428 .project(vec![try_cast(col("a"), DataType::Float64)])?
1429 .build()?;
1430
1431 assert_optimized_plan_equal!(
1432 plan,
1433 @r"
1434 Projection: TRY_CAST(test.a AS Float64)
1435 TableScan: test projection=[a]
1436 "
1437 )
1438 }
1439
1440 #[test]
1441 fn test_similar_to() -> Result<()> {
1442 let table_scan = test_table_scan()?;
1443 let expr = Box::new(col("a"));
1444 let pattern = Box::new(lit("[0-9]"));
1445 let similar_to_expr =
1446 Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1447 let plan = LogicalPlanBuilder::from(table_scan)
1448 .project(vec![similar_to_expr])?
1449 .build()?;
1450
1451 assert_optimized_plan_equal!(
1452 plan,
1453 @r#"
1454 Projection: test.a SIMILAR TO Utf8("[0-9]")
1455 TableScan: test projection=[a]
1456 "#
1457 )
1458 }
1459
1460 #[test]
1461 fn test_between() -> Result<()> {
1462 let table_scan = test_table_scan()?;
1463 let plan = LogicalPlanBuilder::from(table_scan)
1464 .project(vec![col("a").between(lit(1), lit(3))])?
1465 .build()?;
1466
1467 assert_optimized_plan_equal!(
1468 plan,
1469 @r"
1470 Projection: test.a BETWEEN Int32(1) AND Int32(3)
1471 TableScan: test projection=[a]
1472 "
1473 )
1474 }
1475
1476 #[test]
1478 fn test_case_merged() -> Result<()> {
1479 let table_scan = test_table_scan()?;
1480 let plan = LogicalPlanBuilder::from(table_scan)
1481 .project(vec![col("a"), lit(0).alias("d")])?
1482 .project(vec![
1483 col("a"),
1484 when(col("a").eq(lit(1)), lit(10))
1485 .otherwise(col("d"))?
1486 .alias("d"),
1487 ])?
1488 .build()?;
1489
1490 assert_optimized_plan_equal!(
1491 plan,
1492 @r"
1493 Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d
1494 TableScan: test projection=[a]
1495 "
1496 )
1497 }
1498
1499 #[test]
1502 fn test_derived_column() -> Result<()> {
1503 let table_scan = test_table_scan()?;
1504 let plan = LogicalPlanBuilder::from(table_scan)
1505 .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])?
1506 .project(vec![
1507 col("a"),
1508 when(col("a").eq(lit(1)), lit(10))
1509 .otherwise(col("d"))?
1510 .alias("d"),
1511 ])?
1512 .build()?;
1513
1514 assert_optimized_plan_equal!(
1515 plan,
1516 @r"
1517 Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d
1518 Projection: test.a + Int32(1) AS a, Int32(0) AS d
1519 TableScan: test projection=[a]
1520 "
1521 )
1522 }
1523
1524 #[test]
1527 fn test_user_defined_logical_plan_node() -> Result<()> {
1528 let table_scan = test_table_scan()?;
1529 let custom_plan = LogicalPlan::Extension(Extension {
1530 node: Arc::new(NoOpUserDefined::new(
1531 Arc::clone(table_scan.schema()),
1532 Arc::new(table_scan.clone()),
1533 )),
1534 });
1535 let plan = LogicalPlanBuilder::from(custom_plan)
1536 .project(vec![col("a"), lit(0).alias("d")])?
1537 .build()?;
1538
1539 assert_optimized_plan_equal!(
1540 plan,
1541 @r"
1542 Projection: test.a, Int32(0) AS d
1543 NoOpUserDefined
1544 TableScan: test projection=[a]
1545 "
1546 )
1547 }
1548
/// A column referenced only by the extension node's own expressions (`b`) must
/// still be kept in the scan, even though the outer projection never selects it.
#[test]
fn test_user_defined_logical_plan_node2() -> Result<()> {
    let scan = test_table_scan()?;
    let node_exprs = vec![Expr::Column(Column::from_qualified_name("b"))];
    let node = NoOpUserDefined::new(Arc::clone(scan.schema()), Arc::new(scan.clone()))
        .with_exprs(node_exprs);
    let custom_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(node),
    });
    let plan = LogicalPlanBuilder::from(custom_plan)
        .project(vec![col("a"), lit(0).alias("d")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, Int32(0) AS d
          NoOpUserDefined
            TableScan: test projection=[a, b]
        "
    )
}
1579
/// Every column inside a compound extension-node expression (`b + c`) must
/// survive pruning, so the scan keeps `a`, `b` and `c`.
#[test]
fn test_user_defined_logical_plan_node3() -> Result<()> {
    let scan = test_table_scan()?;
    let qualified = |name: &str| Expr::Column(Column::from_qualified_name(name));
    let b_plus_c = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(qualified("b")),
        Operator::Plus,
        Box::new(qualified("c")),
    ));
    let node = NoOpUserDefined::new(Arc::clone(scan.schema()), Arc::new(scan.clone()))
        .with_exprs(vec![b_plus_c]);
    let custom_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(node),
    });
    let plan = LogicalPlanBuilder::from(custom_plan)
        .project(vec![col("a"), lit(0).alias("d")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, Int32(0) AS d
          NoOpUserDefined
            TableScan: test projection=[a, b, c]
        "
    )
}
1618
/// Pruning pushes column requirements into each input of a multi-input
/// extension node (a user-defined cross join).
#[test]
fn test_user_defined_logical_plan_node4() -> Result<()> {
    let left_input = Arc::new(test_table_scan_with_name("l")?);
    let right_input = Arc::new(test_table_scan_with_name("r")?);
    let custom_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(UserDefinedCrossJoin::new(left_input, right_input)),
    });
    let selected = vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")];
    let plan = LogicalPlanBuilder::from(custom_plan)
        .project(selected)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: l.a, l.c, r.a, Int32(0) AS d
          UserDefinedCrossJoin
            TableScan: l projection=[a, c]
            TableScan: r projection=[a]
        "
    )
}
1647
/// An aggregate with no grouping keys only needs its aggregate's input column.
#[test]
fn aggregate_no_group_by() -> Result<()> {
    let no_groups: Vec<Expr> = vec![];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(no_groups, vec![max(col("b"))])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
          TableScan: test projection=[b]
        "
    )
}
1664
/// Grouping keys and aggregate inputs are both retained in the scan.
#[test]
fn aggregate_group_by() -> Result<()> {
    let group_by = vec![col("c")];
    let aggregates = vec![max(col("b"))];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]
          TableScan: test projection=[b, c]
        "
    )
}
1681
/// Column requirements expressed against a subquery alias are mapped back to
/// the underlying scan.
#[test]
fn aggregate_group_by_with_table_alias() -> Result<()> {
    let aliased = LogicalPlanBuilder::from(test_table_scan()?).alias("a")?;
    let plan = aliased
        .aggregate(vec![col("c")], vec![max(col("b"))])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]
          SubqueryAlias: a
            TableScan: test projection=[b, c]
        "
    )
}
1700
/// A filter column used only below the aggregate is trimmed by a projection
/// inserted between the filter and the aggregate.
#[test]
fn aggregate_no_group_by_with_filter() -> Result<()> {
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("c").gt(lit(1)))?;
    let plan = filtered
        .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[]], aggr=[[max(test.b)]]
          Projection: test.b
            Filter: test.c > Int32(1)
              TableScan: test projection=[b, c]
        "
    )
}
1720
/// Column names containing periods are handled as unqualified names and do
/// not break projection pruning.
#[test]
fn aggregate_with_periods() -> Result<()> {
    let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]);
    // Build the reference as an unqualified column so "tag.one" is not parsed
    // as `table.column`.
    let tag_one = || col(Column::new_unqualified("tag.one"));

    let plan = table_scan(Some("m4"), &schema, None)?
        .aggregate(Vec::<Expr>::new(), vec![max(tag_one()).alias("tag.one")])?
        .project([tag_one()])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]
          TableScan: m4 projection=[tag.one]
        "
    )
}
1747
/// Two stacked projections collapse into a single reordering projection.
#[test]
fn redundant_project() -> Result<()> {
    let scan = test_table_scan()?;
    let identity = vec![col("a"), col("b"), col("c")];
    let reordered = vec![col("a"), col("c"), col("b")];
    let plan = LogicalPlanBuilder::from(scan)
        .project(identity)?
        .project(reordered)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test.c, test.b
          TableScan: test projection=[a, b, c]
        "
    )
}
1764
/// A scan whose projection is already out of order is left unchanged.
#[test]
fn reorder_scan() -> Result<()> {
    let schema = Schema::new(test_table_scan_fields());
    let scan_order = Some(vec![1, 0, 2]);
    let plan = table_scan(Some("test"), &schema, scan_order)?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test projection=[b, a, c]"
    )
}
1775
/// Pruning an out-of-order scan keeps the scan's original column order.
#[test]
fn reorder_scan_projection() -> Result<()> {
    let schema = Schema::new(test_table_scan_fields());
    let scan_order = Some(vec![1, 0, 2]);
    let plan = table_scan(Some("test"), &schema, scan_order)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test.b
          TableScan: test projection=[b, a]
        "
    )
}
1791
/// A projection that only reorders columns is kept above the full scan.
#[test]
fn reorder_projection() -> Result<()> {
    let reversed = vec![col("c"), col("b"), col("a")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(reversed)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.c, test.b, test.a
          TableScan: test projection=[a, b, c]
        "
    )
}
1807
/// Reordering projections separated by filters are not merged across the filters.
#[test]
fn noncontinuous_redundant_projection() -> Result<()> {
    let greater_than_one = |name: &str| col(name).gt(lit(1));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("c"), col("b"), col("a")])?
        .filter(greater_than_one("c"))?
        .project(vec![col("c"), col("a"), col("b")])?
        .filter(greater_than_one("b"))?
        .filter(greater_than_one("a"))?
        .project(vec![col("a"), col("c"), col("b")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a, test.c, test.b
          Filter: test.a > Int32(1)
            Filter: test.b > Int32(1)
              Projection: test.c, test.a, test.b
                Filter: test.c > Int32(1)
                  Projection: test.c, test.b, test.a
                    TableScan: test projection=[a, b, c]
        "
    )
}
1833
/// Selecting exactly the join's remaining output columns removes the
/// projection, leaving the join as the root of the plan.
#[test]
fn join_schema_trim_full_join_column_projection() -> Result<()> {
    let left = test_table_scan()?;
    let right_schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
    let right = scan_empty(Some("test2"), &right_schema, None)?.build()?;

    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
        .project(vec![col("a"), col("b"), col("c1")])?
        .build()?;

    let optimized = optimize(plan)?;

    assert_snapshot!(
        optimized.clone(),
        @r"
        Left Join: test.a = test2.c1
          TableScan: test projection=[a, b]
          TableScan: test2 projection=[c1]
        "
    );

    // The root is the join itself, so check its schema directly; `c1` is
    // nullable because it comes from the right side of a LEFT join.
    let expected_schema = DFSchema::new_with_metadata(
        vec![
            (
                Some("test".into()),
                Arc::new(Field::new("a", DataType::UInt32, false)),
            ),
            (
                Some("test".into()),
                Arc::new(Field::new("b", DataType::UInt32, false)),
            ),
            (
                Some("test2".into()),
                Arc::new(Field::new("c1", DataType::UInt32, true)),
            ),
        ],
        HashMap::new(),
    )?;
    assert_eq!(**optimized.schema(), expected_schema);

    Ok(())
}
1883
/// When the projection drops a join key, the projection is kept, but the join
/// below it still exposes both join columns.
#[test]
fn join_schema_trim_partial_join_column_projection() -> Result<()> {
    let left = test_table_scan()?;
    let right_schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
    let right = scan_empty(Some("test2"), &right_schema, None)?.build()?;

    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Left, (vec!["a"], vec!["c1"]), None)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    let optimized = optimize(plan)?;

    assert_snapshot!(
        optimized.clone(),
        @r"
        Projection: test.a, test.b
          Left Join: test.a = test2.c1
            TableScan: test projection=[a, b]
            TableScan: test2 projection=[c1]
        "
    );

    // Inspect the join directly beneath the root projection.
    let join = optimized.inputs()[0];
    let expected_schema = DFSchema::new_with_metadata(
        vec![
            (
                Some("test".into()),
                Arc::new(Field::new("a", DataType::UInt32, false)),
            ),
            (
                Some("test".into()),
                Arc::new(Field::new("b", DataType::UInt32, false)),
            ),
            (
                Some("test2".into()),
                Arc::new(Field::new("c1", DataType::UInt32, true)),
            ),
        ],
        HashMap::new(),
    )?;
    assert_eq!(**join.schema(), expected_schema);

    Ok(())
}
1938
/// A USING join keeps both sides' key columns in its schema after pruning.
#[test]
fn join_schema_trim_using_join() -> Result<()> {
    let left = test_table_scan()?;
    let right_schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
    let right = scan_empty(Some("test2"), &right_schema, None)?.build()?;

    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Left, vec!["a".into()])?
        .project(vec![col("a"), col("b")])?
        .build()?;

    let optimized = optimize(plan)?;

    assert_snapshot!(
        optimized.clone(),
        @r"
        Projection: test.a, test.b
          Left Join: Using test.a = test2.a
            TableScan: test projection=[a, b]
            TableScan: test2 projection=[a]
        "
    );

    // Inspect the join directly beneath the root projection.
    let join = optimized.inputs()[0];
    let expected_schema = DFSchema::new_with_metadata(
        vec![
            (
                Some("test".into()),
                Arc::new(Field::new("a", DataType::UInt32, false)),
            ),
            (
                Some("test".into()),
                Arc::new(Field::new("b", DataType::UInt32, false)),
            ),
            (
                Some("test2".into()),
                Arc::new(Field::new("a", DataType::UInt32, true)),
            ),
        ],
        HashMap::new(),
    )?;
    assert_eq!(**join.schema(), expected_schema);

    Ok(())
}
1991
/// A projection containing only a cast keeps just the cast's input column.
#[test]
fn cast() -> Result<()> {
    let c_as_f64 = Expr::Cast(Cast::new(Box::new(col("c")), DataType::Float64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![c_as_f64])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: CAST(test.c AS Float64)
          TableScan: test projection=[c]
        "
    )
}
2011
/// A plain column projection is folded into the scan's projection.
#[test]
fn table_scan_projected_schema() -> Result<()> {
    let scan = test_table_scan()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    assert_fields_eq(&plan, vec!["a", "b"]);

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test projection=[a, b]"
    )
}
2028
/// A directly constructed `Projection` over qualified columns folds into the scan.
#[test]
fn table_scan_projected_schema_non_qualified_relation() -> Result<()> {
    let scan = test_table_scan()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let projection =
        Projection::try_new(vec![col("test.a"), col("test.b")], Arc::new(scan))?;
    let plan = LogicalPlan::Projection(projection);
    assert_fields_eq(&plan, vec!["a", "b"]);

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test projection=[a, b]"
    )
}
2050
/// Column requirements flow through a `Limit` down to the scan.
#[test]
fn table_limit() -> Result<()> {
    let scan = test_table_scan()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let limited = LogicalPlanBuilder::from(scan)
        .project(vec![col("c"), col("a")])?
        .limit(0, Some(5))?;
    let plan = limited.build()?;
    assert_fields_eq(&plan, vec!["c", "a"]);

    assert_optimized_plan_equal!(
        plan,
        @r"
        Limit: skip=0, fetch=5
          Projection: test.c, test.a
            TableScan: test projection=[a, c]
        "
    )
}
2073
/// A bare scan gets an explicit projection of every column.
#[test]
fn table_scan_without_projection() -> Result<()> {
    let plan = LogicalPlanBuilder::from(test_table_scan()?).build()?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test projection=[a, b, c]"
    )
}
2084
/// A projection of only literals needs no columns from the scan.
#[test]
fn table_scan_with_literal_projection() -> Result<()> {
    let literals = vec![lit(1_i64), lit(2_i64)];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(literals)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: Int64(1), Int64(2)
          TableScan: test projection=[]
        "
    )
}
2099
/// Column `b` is projected but never consumed downstream, so it is pruned from
/// both the intermediate projection and the scan.
#[test]
fn table_unused_column() -> Result<()> {
    let scan = test_table_scan()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let unoptimized = LogicalPlanBuilder::from(scan)
        .project(vec![col("c"), col("a"), col("b")])?
        .filter(col("c").gt(lit(1)))?
        .aggregate(vec![col("c")], vec![max(col("a"))])?
        .build()?;
    assert_fields_eq(&unoptimized, vec!["c", "max(test.a)"]);

    let plan = optimize(unoptimized).expect("failed to optimize plan");
    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]
          Filter: test.c > Int32(1)
            Projection: test.c, test.a
              TableScan: test projection=[a, c]
        "
    )
}
2127
/// An outer projection of only a literal makes the inner column projection
/// redundant, leaving an empty scan projection.
#[test]
fn table_unused_projection() -> Result<()> {
    let scan = test_table_scan()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let inner = LogicalPlanBuilder::from(scan).project(vec![col("b")])?;
    let plan = inner.project(vec![lit(1).alias("a")])?.build()?;
    assert_fields_eq(&plan, vec!["a"]);

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: Int32(1) AS a
          TableScan: test projection=[]
        "
    )
}
2151
/// Columns used only by a fully pushed-down scan filter are not added to the
/// scan's projection; the filter stays attached to the scan.
#[test]
fn table_full_filter_pushdown() -> Result<()> {
    let schema = Schema::new(test_table_scan_fields());
    let pushed_filters = vec![col("b").eq(lit(1))];
    let scan = table_scan_with_filters(Some("test"), &schema, None, pushed_filters)?
        .build()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("b")])?
        .project(vec![lit(1).alias("a")])?
        .build()?;
    assert_fields_eq(&plan, vec!["a"]);

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: Int32(1) AS a
          TableScan: test projection=[], full_filters=[b = Int32(1)]
        "
    )
}
2182
/// Running the rule a second time on an already optimized plan changes nothing,
/// judged by comparing the plans' `Debug` output.
#[test]
fn test_double_optimization() -> Result<()> {
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("b")])?
        .project(vec![lit(1).alias("a")])?
        .build()?;

    let once = optimize(plan).expect("failed to optimize plan");
    let twice = optimize(once.clone()).expect("failed to optimize plan");

    let (first, second) = (format!("{once:?}"), format!("{twice:?}"));
    assert_eq!(first, second);
    Ok(())
}
2202
/// An aggregate expression (`min(b)`) that is never referenced above is
/// removed from the aggregate.
#[test]
fn table_unused_aggregate() -> Result<()> {
    let scan = test_table_scan()?;
    assert_eq!(3, scan.schema().fields().len());
    assert_fields_eq(&scan, vec!["a", "b", "c"]);

    let group_by = vec![col("a"), col("c")];
    let aggregates = vec![max(col("b")), min(col("b"))];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(group_by, aggregates)?
        .filter(col("c").gt(lit(1)))?
        .project(vec![col("c"), col("a"), col("max(test.b)")])?
        .build()?;
    assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]);

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.c, test.a, max(test.b)
          Filter: test.c > Int32(1)
            Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]
              TableScan: test projection=[a, b, c]
        "
    )
}
2229
/// A column referenced only in an aggregate's FILTER clause (`c`) is still required.
#[test]
fn aggregate_filter_pushdown() -> Result<()> {
    let filtered_count = count_udaf()
        .call(vec![col("b")])
        .filter(col("c").gt(lit(42)))
        .build()?
        .alias("count2");
    let aggregates = vec![count(col("b")), filtered_count];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], aggregates)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]
          TableScan: test projection=[a, b, c]
        "
    )
}
2252
/// A projection above `Distinct` is not pushed through it: every distinct key
/// column must be kept.
#[test]
fn pushdown_through_distinct() -> Result<()> {
    let distinct = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .distinct()?;
    let plan = distinct.project(vec![col("a")])?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a
          Distinct:
            TableScan: test projection=[a, b]
        "
    )
}
2272
/// Column requirements are propagated through stacked window aggregates.
/// The second window needs `test.b` and the first window's output; the first
/// window needs `test.a` and `test.b`.
#[test]
fn test_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Propagate a builder failure as a test error with `?` instead of
    // panicking via `unwrap()`; the function already returns `Result`.
    let max1 = Expr::from(expr::WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(max_udaf()),
        vec![col("test.a")],
    ))
    .partition_by(vec![col("test.b")])
    .build()?;

    let max2 = Expr::from(expr::WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(max_udaf()),
        vec![col("test.b")],
    ));
    // Reference each window output by its schema name in the final projection.
    let col1 = col(max1.schema_name().to_string());
    let col2 = col(max2.schema_name().to_string());

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![max1])?
        .window(vec![max2])?
        .project(vec![col1, col2])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
          WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
            Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
              WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
                TableScan: test projection=[a, b]
        "
    )
}
2309
2310 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
2311
2312 fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
2313 let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
2314 let optimized_plan =
2315 optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
2316 Ok(optimized_plan)
2317 }
2318}