1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31 internal_err, plan_err, qualified_name, Column, DFSchema, Result,
32};
33use datafusion_expr::expr::WindowFunction;
34use datafusion_expr::expr_rewriter::replace_col;
35use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
36use datafusion_expr::utils::{
37 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
38};
39use datafusion_expr::{
40 and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
41};
42
43use crate::optimizer::ApplyOrder;
44use crate::simplify_expressions::simplify_predicates;
45use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
46use crate::{OptimizerConfig, OptimizerRule};
47
/// Optimizer rule registered under the name `"push_down_filter"`.
///
/// Its [`OptimizerRule::rewrite`] implementation takes `Filter` nodes (and the
/// `ON` filters of `Join` nodes) and tries to move their predicates below the
/// child operator, so that rows are discarded as early as possible.
///
/// The struct carries no state; all behavior lives in the trait implementation.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
137
138pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
163 match join_type {
164 JoinType::Inner => (true, true),
165 JoinType::Left => (true, false),
166 JoinType::Right => (false, true),
167 JoinType::Full => (false, false),
168 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
171 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
174 }
175}
176
177pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
187 match join_type {
188 JoinType::Inner => (true, true),
189 JoinType::Left => (false, true),
190 JoinType::Right => (true, false),
191 JoinType::Full => (false, false),
192 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
193 JoinType::LeftAnti => (false, true),
194 JoinType::RightAnti => (true, false),
195 JoinType::LeftMark => (false, true),
196 JoinType::RightMark => (true, false),
197 }
198}
199
/// Answers "does this predicate only reference columns of one join side?"
/// for a pair of join input schemas.
///
/// The column sets are derived from the schemas on first use and cached, so a
/// side whose columns are never queried costs nothing.
#[derive(Debug)]
struct ColumnChecker<'a> {
    /// Schema of the left join input.
    left_schema: &'a DFSchema,
    /// Columns of `left_schema`; `None` until first needed.
    left_columns: Option<HashSet<Column>>,
    /// Schema of the right join input.
    right_schema: &'a DFSchema,
    /// Columns of `right_schema`; `None` until first needed.
    right_columns: Option<HashSet<Column>>,
}
213
214impl<'a> ColumnChecker<'a> {
215 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
216 Self {
217 left_schema,
218 left_columns: None,
219 right_schema,
220 right_columns: None,
221 }
222 }
223
224 fn is_left_only(&mut self, predicate: &Expr) -> bool {
226 if self.left_columns.is_none() {
227 self.left_columns = Some(schema_columns(self.left_schema));
228 }
229 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
230 }
231
232 fn is_right_only(&mut self, predicate: &Expr) -> bool {
234 if self.right_columns.is_none() {
235 self.right_columns = Some(schema_columns(self.right_schema));
236 }
237 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
238 }
239}
240
241fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
243 schema
244 .iter()
245 .flat_map(|(qualifier, field)| {
246 [
247 Column::new(qualifier.cloned(), field.name()),
248 Column::new_unqualified(field.name()),
250 ]
251 })
252 .collect::<HashSet<_>>()
253}
254
255fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
257 let mut is_evaluate = true;
258 predicate.apply(|expr| match expr {
259 Expr::Column(_)
260 | Expr::Literal(_, _)
261 | Expr::Placeholder(_)
262 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
263 Expr::Exists { .. }
264 | Expr::InSubquery(_)
265 | Expr::ScalarSubquery(_)
266 | Expr::OuterReferenceColumn(_, _)
267 | Expr::Unnest(_) => {
268 is_evaluate = false;
269 Ok(TreeNodeRecursion::Stop)
270 }
271 Expr::Alias(_)
272 | Expr::BinaryExpr(_)
273 | Expr::Like(_)
274 | Expr::SimilarTo(_)
275 | Expr::Not(_)
276 | Expr::IsNotNull(_)
277 | Expr::IsNull(_)
278 | Expr::IsTrue(_)
279 | Expr::IsFalse(_)
280 | Expr::IsUnknown(_)
281 | Expr::IsNotTrue(_)
282 | Expr::IsNotFalse(_)
283 | Expr::IsNotUnknown(_)
284 | Expr::Negative(_)
285 | Expr::Between(_)
286 | Expr::Case(_)
287 | Expr::Cast(_)
288 | Expr::TryCast(_)
289 | Expr::InList { .. }
290 | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
291 #[expect(deprecated)]
293 Expr::AggregateFunction(_)
294 | Expr::WindowFunction(_)
295 | Expr::Wildcard { .. }
296 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
297 })?;
298 Ok(is_evaluate)
299}
300
301fn extract_or_clauses_for_join<'a>(
335 filters: &'a [Expr],
336 schema: &'a DFSchema,
337) -> impl Iterator<Item = Expr> + 'a {
338 let schema_columns = schema_columns(schema);
339
340 filters.iter().filter_map(move |expr| {
342 if let Expr::BinaryExpr(BinaryExpr {
343 left,
344 op: Operator::Or,
345 right,
346 }) = expr
347 {
348 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
349 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
350
351 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
353 return Some(or(left_expr, right_expr));
354 }
355 }
356 None
357 })
358}
359
360fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
372 let mut predicate = None;
373
374 match expr {
375 Expr::BinaryExpr(BinaryExpr {
376 left: l_expr,
377 op: Operator::Or,
378 right: r_expr,
379 }) => {
380 let l_expr = extract_or_clause(l_expr, schema_columns);
381 let r_expr = extract_or_clause(r_expr, schema_columns);
382
383 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
384 predicate = Some(or(l_expr, r_expr));
385 }
386 }
387 Expr::BinaryExpr(BinaryExpr {
388 left: l_expr,
389 op: Operator::And,
390 right: r_expr,
391 }) => {
392 let l_expr = extract_or_clause(l_expr, schema_columns);
393 let r_expr = extract_or_clause(r_expr, schema_columns);
394
395 match (l_expr, r_expr) {
396 (Some(l_expr), Some(r_expr)) => {
397 predicate = Some(and(l_expr, r_expr));
398 }
399 (Some(l_expr), None) => {
400 predicate = Some(l_expr);
401 }
402 (None, Some(r_expr)) => {
403 predicate = Some(r_expr);
404 }
405 (None, None) => {
406 predicate = None;
407 }
408 }
409 }
410 _ => {
411 if has_all_column_refs(expr, schema_columns) {
412 predicate = Some(expr.clone());
413 }
414 }
415 }
416
417 predicate
418}
419
420fn push_down_all_join(
422 predicates: Vec<Expr>,
423 inferred_join_predicates: Vec<Expr>,
424 mut join: Join,
425 on_filter: Vec<Expr>,
426) -> Result<Transformed<LogicalPlan>> {
427 let is_inner_join = join.join_type == JoinType::Inner;
428 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
430
431 let left_schema = join.left.schema();
436 let right_schema = join.right.schema();
437 let mut left_push = vec![];
438 let mut right_push = vec![];
439 let mut keep_predicates = vec![];
440 let mut join_conditions = vec![];
441 let mut checker = ColumnChecker::new(left_schema, right_schema);
442 for predicate in predicates {
443 if left_preserved && checker.is_left_only(&predicate) {
444 left_push.push(predicate);
445 } else if right_preserved && checker.is_right_only(&predicate) {
446 right_push.push(predicate);
447 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
448 join_conditions.push(predicate);
451 } else {
452 keep_predicates.push(predicate);
453 }
454 }
455
456 for predicate in inferred_join_predicates {
458 if left_preserved && checker.is_left_only(&predicate) {
459 left_push.push(predicate);
460 } else if right_preserved && checker.is_right_only(&predicate) {
461 right_push.push(predicate);
462 }
463 }
464
465 let mut on_filter_join_conditions = vec![];
466 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
467
468 if !on_filter.is_empty() {
469 for on in on_filter {
470 if on_left_preserved && checker.is_left_only(&on) {
471 left_push.push(on)
472 } else if on_right_preserved && checker.is_right_only(&on) {
473 right_push.push(on)
474 } else {
475 on_filter_join_conditions.push(on)
476 }
477 }
478 }
479
480 if left_preserved {
483 left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
484 left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
485 }
486 if right_preserved {
487 right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
488 right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
489 }
490
491 if on_left_preserved {
494 left_push.extend(extract_or_clauses_for_join(
495 &on_filter_join_conditions,
496 left_schema,
497 ));
498 }
499 if on_right_preserved {
500 right_push.extend(extract_or_clauses_for_join(
501 &on_filter_join_conditions,
502 right_schema,
503 ));
504 }
505
506 if let Some(predicate) = conjunction(left_push) {
507 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
508 }
509 if let Some(predicate) = conjunction(right_push) {
510 join.right =
511 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
512 }
513
514 join_conditions.extend(on_filter_join_conditions);
516 join.filter = conjunction(join_conditions);
517
518 let plan = LogicalPlan::Join(join);
520 let plan = if let Some(predicate) = conjunction(keep_predicates) {
521 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
522 } else {
523 plan
524 };
525 Ok(Transformed::yes(plan))
526}
527
528fn push_down_join(
529 join: Join,
530 parent_predicate: Option<&Expr>,
531) -> Result<Transformed<LogicalPlan>> {
532 let predicates = parent_predicate
534 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
535
536 let on_filters = join
538 .filter
539 .as_ref()
540 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
541
542 let inferred_join_predicates =
544 infer_join_predicates(&join, &predicates, &on_filters)?;
545
546 if on_filters.is_empty()
547 && predicates.is_empty()
548 && inferred_join_predicates.is_empty()
549 {
550 return Ok(Transformed::no(LogicalPlan::Join(join)));
551 }
552
553 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
554}
555
556fn infer_join_predicates(
566 join: &Join,
567 predicates: &[Expr],
568 on_filters: &[Expr],
569) -> Result<Vec<Expr>> {
570 let join_col_keys = join
572 .on
573 .iter()
574 .filter_map(|(l, r)| {
575 let left_col = l.try_as_col()?;
576 let right_col = r.try_as_col()?;
577 Some((left_col, right_col))
578 })
579 .collect::<Vec<_>>();
580
581 let join_type = join.join_type;
582
583 let mut inferred_predicates = InferredPredicates::new(join_type);
584
585 infer_join_predicates_from_predicates(
586 &join_col_keys,
587 predicates,
588 &mut inferred_predicates,
589 )?;
590
591 infer_join_predicates_from_on_filters(
592 &join_col_keys,
593 join_type,
594 on_filters,
595 &mut inferred_predicates,
596 )?;
597
598 Ok(inferred_predicates.predicates)
599}
600
/// Accumulator for predicates inferred from join keys.
struct InferredPredicates {
    /// Predicates inferred so far, in insertion order.
    predicates: Vec<Expr>,
    /// Whether the join being processed is an inner join; inner joins accept
    /// every inferred predicate without the extra check performed in
    /// `try_build_predicate`.
    is_inner_join: bool,
}
613
614impl InferredPredicates {
615 fn new(join_type: JoinType) -> Self {
616 Self {
617 predicates: vec![],
618 is_inner_join: matches!(join_type, JoinType::Inner),
619 }
620 }
621
622 fn try_build_predicate(
623 &mut self,
624 predicate: Expr,
625 replace_map: &HashMap<&Column, &Column>,
626 ) -> Result<()> {
627 if self.is_inner_join
628 || matches!(
629 is_restrict_null_predicate(
630 predicate.clone(),
631 replace_map.keys().cloned()
632 ),
633 Ok(true)
634 )
635 {
636 self.predicates.push(replace_col(predicate, replace_map)?);
637 }
638
639 Ok(())
640 }
641}
642
/// Infers predicates from the conjuncts of a filter located above the join.
///
/// Inference is enabled in both directions (left key to right key and right
/// key to left key); the results are appended to `inferred_predicates`.
fn infer_join_predicates_from_predicates(
    join_col_keys: &[(&Column, &Column)],
    predicates: &[Expr],
    inferred_predicates: &mut InferredPredicates,
) -> Result<()> {
    // Both const flags are `true`: replace left keys by right keys and vice versa.
    infer_join_predicates_impl::<true, true>(
        join_col_keys,
        predicates,
        inferred_predicates,
    )
}
662
663fn infer_join_predicates_from_on_filters(
675 join_col_keys: &[(&Column, &Column)],
676 join_type: JoinType,
677 on_filters: &[Expr],
678 inferred_predicates: &mut InferredPredicates,
679) -> Result<()> {
680 match join_type {
681 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
682 JoinType::Inner => infer_join_predicates_impl::<true, true>(
683 join_col_keys,
684 on_filters,
685 inferred_predicates,
686 ),
687 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
688 infer_join_predicates_impl::<true, false>(
689 join_col_keys,
690 on_filters,
691 inferred_predicates,
692 )
693 }
694 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
695 infer_join_predicates_impl::<false, true>(
696 join_col_keys,
697 on_filters,
698 inferred_predicates,
699 )
700 }
701 }
702}
703
704fn infer_join_predicates_impl<
720 const ENABLE_LEFT_TO_RIGHT: bool,
721 const ENABLE_RIGHT_TO_LEFT: bool,
722>(
723 join_col_keys: &[(&Column, &Column)],
724 input_predicates: &[Expr],
725 inferred_predicates: &mut InferredPredicates,
726) -> Result<()> {
727 for predicate in input_predicates {
728 let mut join_cols_to_replace = HashMap::new();
729
730 for &col in &predicate.column_refs() {
731 for (l, r) in join_col_keys.iter() {
732 if ENABLE_LEFT_TO_RIGHT && col == *l {
733 join_cols_to_replace.insert(col, *r);
734 break;
735 }
736 if ENABLE_RIGHT_TO_LEFT && col == *r {
737 join_cols_to_replace.insert(col, *l);
738 break;
739 }
740 }
741 }
742 if join_cols_to_replace.is_empty() {
743 continue;
744 }
745
746 inferred_predicates
747 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
748 }
749 Ok(())
750}
751
impl OptimizerRule for PushDownFilter {
    /// Name under which this rule is identified.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// Requests that the rule be applied top-down.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// Reports that this rule provides a `rewrite` implementation.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Tries to push filter predicates one level down the plan.
    ///
    /// * A bare `Join` has its `ON` filter pushed into its inputs where
    ///   possible.
    /// * A `Filter` is first simplified and then handled according to the
    ///   kind of its input; see the comments on the individual match arms.
    /// * Any other node is returned unchanged.
    ///
    /// The returned `Transformed` flag tells whether the plan was changed.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // A join without a Filter above it: only its ON filter is considered.
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        // Captured before `plan` is destructured; used by the Union arm.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        // Simplify the conjuncts. The predicate is only replaced when the
        // number of conjuncts changed.
        let predicate = split_conjunction_owned(filter.predicate.clone());
        let old_predicate_len = predicate.len();
        let new_predicates = simplify_predicates(predicate)?;
        if old_predicate_len != new_predicates.len() {
            let Some(new_predicate) = conjunction(new_predicates) else {
                // Nothing is left of the predicate: drop the Filter entirely.
                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
            };
            filter.predicate = new_predicate;
        }

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over Filter: merge into one Filter and try again.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                // The IndexSet removes duplicate conjuncts while keeping
                // first-seen order.
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                #[allow(clippy::used_underscore_binding)]
                self.rewrite(new_filter, _config)
            }
            // Repartition, Distinct and Sort: the whole Filter is moved
            // below the node unchanged.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            // SubqueryAlias: rewrite column references from the alias's
            // names to the input's names, then move the Filter below.
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                // Maps the alias's qualified column name to the input column
                // at the same position.
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            // Projection: delegate to `rewrite_projection`; predicates it
            // could not push stay in a Filter above the new projection.
            LogicalPlan::Projection(projection) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: restore the original Filter.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            // Unnest: only predicates that do not reference unnested output
            // columns are moved below the Unnest.
            LogicalPlan::Unnest(mut unnest) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                let mut unnest_struct_columns = vec![];

                // Build the columns named `<struct field>.<child field>` for
                // every struct column being unnested.
                for idx in &unnest.struct_type_columns {
                    let (sub_qualifier, field) =
                        unnest.input.schema().qualified_field(*idx);
                    let field_name = field.name().clone();

                    if let DataType::Struct(children) = field.data_type() {
                        for child in children {
                            let child_name = child.name().clone();
                            unnest_struct_columns.push(Column::new(
                                sub_qualifier.cloned(),
                                format!("{field_name}.{child_name}"),
                            ));
                        }
                    }
                }

                for predicate in predicates {
                    // Columns referenced by this predicate.
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    let contains_list_columns =
                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                            accum.contains(&unnest_list.output_column)
                        });
                    let contains_struct_columns =
                        unnest_struct_columns.iter().any(|c| accum.contains(c));

                    if contains_list_columns || contains_struct_columns {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Nothing can be pushed: restore the original Filter.
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Take the input out of the Unnest; `insert_below` puts the
                // new Filter in its place.
                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot be `None`: the list was checked to be non-empty above.
                    conjunction(non_unnest_predicates).unwrap(),
                    unnest_input,
                )?);

                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                // Predicates on unnested columns stay above the Unnest.
                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            // Union: a copy of the Filter, with columns renamed to the
            // respective input's columns, is placed on top of every input.
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    // Maps the union's column names to this input's columns
                    // by position.
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                // The new Union reuses the schema captured from the original
                // Filter node.
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            // Aggregate: predicates whose columns are all group-by
            // expressions are moved below; the rest stay above.
            LogicalPlan::Aggregate(agg) => {
                // Output columns that correspond to group-by expressions.
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| {
                        let (relation, name) = e.qualified_name();
                        Column::new(relation, name)
                    })
                    .collect::<HashSet<_>>();

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Below the aggregate the group-by output columns do not
                // exist yet, so references to them are replaced by the
                // group-by expressions themselves.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // Place the pushable predicates below the aggregate.
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Keep the remaining predicates above the aggregate.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Window: predicates whose columns are partition keys common to
            // all window expressions are moved below; the rest stay above.
            LogicalPlan::Window(window) => {
                // Partition-by expressions of one window function, as columns.
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| {
                            let (relation, name) = c.qualified_name();
                            Column::new(relation, name)
                        })
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        // Window expressions are expected to be a window
                        // function, optionally wrapped in an alias.
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    // Panics on any other aliased expression.
                                    unreachable!()
                                }
                            }
                            _ => {
                                // Panics on any other expression kind.
                                unreachable!()
                            }
                        }
                    })
                    // Intersection across all window expressions; an empty
                    // set when there are no window expressions.
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // Place the pushable predicates below the window.
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Keep the remaining predicates above the window.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Join: the Filter's predicate and the ON filter are distributed
            // together.
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            // TableScan: offer the non-volatile predicates to the table
            // source and keep above the scan whatever it does not handle
            // exactly.
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile predicates are never offered to the source.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // One answer is expected per offered filter.
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                if non_volatile_filters.len() != supported_filters.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        supported_filters.len(),
                        non_volatile_filters.len());
                }

                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // Everything except `Unsupported` is added to the scan's filters.
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Existing scan filters come first; duplicates are removed.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Everything except `Exact`, plus the volatile predicates,
                // must still be evaluated by a Filter above the scan.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            // Extension node: push predicates into every input unless they
            // reference a column the node asks to protect.
            LogicalPlan::Extension(extension_plan) => {
                // A node without inputs has nowhere to push to.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` = push, `false` = keep.
                // Protected columns are matched by unqualified name.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            Ok(false)
                        } else {
                            Ok(true)
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Every conjunct must be kept: restore the original Filter.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // The flags were computed over the same conjunct sequence,
                // so zipping pairs each flag with its predicate.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                // Wrap every input of the extension node in the pushed Filter.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the extension node with the same expressions and
                // the new children.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input: leave the Filter where it is.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1263
1264fn rewrite_projection(
1292 predicates: Vec<Expr>,
1293 mut projection: Projection,
1294) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1295 let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1299 .schema
1300 .iter()
1301 .zip(projection.expr.iter())
1302 .map(|((qualifier, field), expr)| {
1303 let expr = expr.clone().unalias();
1305
1306 (qualified_name(qualifier, field.name()), expr)
1307 })
1308 .partition(|(_, value)| value.is_volatile());
1309
1310 let mut push_predicates = vec![];
1311 let mut keep_predicates = vec![];
1312 for expr in predicates {
1313 if contain(&expr, &volatile_map) {
1314 keep_predicates.push(expr);
1315 } else {
1316 push_predicates.push(expr);
1317 }
1318 }
1319
1320 match conjunction(push_predicates) {
1321 Some(expr) => {
1322 let new_filter = LogicalPlan::Filter(Filter::try_new(
1325 replace_cols_by_name(expr, &non_volatile_map)?,
1326 std::mem::take(&mut projection.input),
1327 )?);
1328
1329 projection.input = Arc::new(new_filter);
1330
1331 Ok((
1332 Transformed::yes(LogicalPlan::Projection(projection)),
1333 conjunction(keep_predicates),
1334 ))
1335 }
1336 None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1337 }
1338}
1339
1340pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1342 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1343}
1344
1345fn insert_below(
1359 plan: LogicalPlan,
1360 new_child: LogicalPlan,
1361) -> Result<Transformed<LogicalPlan>> {
1362 let mut new_child = Some(new_child);
1363 let transformed_plan = plan.map_children(|_child| {
1364 if let Some(new_child) = new_child.take() {
1365 Ok(Transformed::yes(new_child))
1366 } else {
1367 internal_err!("node had more than one input")
1369 }
1370 })?;
1371
1372 if new_child.is_some() {
1374 return internal_err!("node had no inputs");
1375 }
1376
1377 Ok(transformed_plan)
1378}
1379
1380impl PushDownFilter {
1381 #[allow(missing_docs)]
1382 pub fn new() -> Self {
1383 Self {}
1384 }
1385}
1386
1387pub fn replace_cols_by_name(
1389 e: Expr,
1390 replace_map: &HashMap<String, Expr>,
1391) -> Result<Expr> {
1392 e.transform_up(|expr| {
1393 Ok(if let Expr::Column(c) = &expr {
1394 match replace_map.get(&c.flat_name()) {
1395 Some(new_c) => Transformed::yes(new_c.clone()),
1396 None => Transformed::no(expr),
1397 }
1398 } else {
1399 Transformed::no(expr)
1400 })
1401 })
1402 .data()
1403}
1404
1405fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1407 let mut is_contain = false;
1408 e.apply(|expr| {
1409 Ok(if let Expr::Column(c) = &expr {
1410 match check_map.get(&c.flat_name()) {
1411 Some(_) => {
1412 is_contain = true;
1413 TreeNodeRecursion::Stop
1414 }
1415 None => TreeNodeRecursion::Continue,
1416 }
1417 } else {
1418 TreeNodeRecursion::Continue
1419 })
1420 })
1421 .unwrap();
1422 is_contain
1423}
1424
1425#[cfg(test)]
1426mod tests {
1427 use std::any::Any;
1428 use std::cmp::Ordering;
1429 use std::fmt::{Debug, Formatter};
1430
1431 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1432 use async_trait::async_trait;
1433
1434 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1435 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1436 use datafusion_expr::logical_plan::table_scan;
1437 use datafusion_expr::{
1438 col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1439 LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1440 TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1441 WindowFunctionDefinition,
1442 };
1443
1444 use crate::assert_optimized_plan_eq_snapshot;
1445 use crate::optimizer::Optimizer;
1446 use crate::simplify_expressions::SimplifyExpressions;
1447 use crate::test::*;
1448 use crate::OptimizerContext;
1449 use datafusion_expr::test::function_stub::sum;
1450 use insta::assert_snapshot;
1451
1452 use super::*;
1453
1454 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1455
1456 macro_rules! assert_optimized_plan_equal {
1457 (
1458 $plan:expr,
1459 @ $expected:literal $(,)?
1460 ) => {{
1461 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1462 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1463 assert_optimized_plan_eq_snapshot!(
1464 optimizer_ctx,
1465 rules,
1466 $plan,
1467 @ $expected,
1468 )
1469 }};
1470 }
1471
1472 macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1473 (
1474 $plan:expr,
1475 @ $expected:literal $(,)?
1476 ) => {{
1477 let optimizer = Optimizer::with_rules(vec![
1478 Arc::new(SimplifyExpressions::new()),
1479 Arc::new(PushDownFilter::new()),
1480 ]);
1481 let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1482 assert_snapshot!(optimized_plan, @ $expected);
1483 Ok::<(), DataFusionError>(())
1484 }};
1485 }
1486
/// A filter placed above a projection is pushed through it and ends up as a
/// scan-level filter.
#[test]
fn filter_before_projection() -> Result<()> {
    let scan = test_table_scan()?;
    let projected = LogicalPlanBuilder::from(scan).project(vec![col("a"), col("b")])?;
    let predicate = col("a").eq(lit(1i64));
    let plan = projected.filter(predicate)?.build()?;

    // The projection stays on top; the predicate lands on the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1503
/// A filter above a limit must stay where it is: the expected plan is
/// identical to the input plan.
#[test]
fn filter_after_limit() -> Result<()> {
    let scan = test_table_scan()?;
    let limited = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(10))?;
    let predicate = col("a").eq(lit(1i64));
    let plan = limited.filter(predicate)?.build()?;

    // Nothing moves below the limit.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a = Int64(1)
      Limit: skip=0, fetch=10
        Projection: test.a, test.b
          TableScan: test
    "
    )
}
1523
/// A predicate that references no columns at all is still pushed into the
/// scan.
#[test]
fn filter_no_columns() -> Result<()> {
    let scan = test_table_scan()?;
    let constant_predicate = lit(0i64).eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(scan)
        .filter(constant_predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
    )
}
1535
/// A filter is pushed through two stacked projections down to the scan.
#[test]
fn filter_jump_2_plans() -> Result<()> {
    let scan = test_table_scan()?;
    let inner = LogicalPlanBuilder::from(scan).project(vec![col("a"), col("b"), col("c")])?;
    let outer = inner.project(vec![col("c"), col("b")])?;
    let predicate = col("a").eq(lit(1i64));
    let plan = outer.filter(predicate)?.build()?;

    // Both projections are kept; the predicate is below both of them.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.c, test.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1554
/// A filter on a group-by column is pushed below the aggregate.
#[test]
fn filter_move_agg() -> Result<()> {
    let scan = test_table_scan()?;
    let group_keys = vec![col("a")];
    let aggregates = vec![sum(col("b")).alias("total_salary")];
    let aggregated = LogicalPlanBuilder::from(scan).aggregate(group_keys, aggregates)?;
    let plan = aggregated.filter(col("a").gt(lit(10i64)))?.build()?;

    // `a` is a grouping key, so the predicate reaches the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1571
/// Same as `filter_move_agg`, but the column names start with `$` to check
/// that special characters in identifiers do not block the push-down.
#[test]
fn filter_move_agg_special() -> Result<()> {
    let fields: Vec<Field> = ["$a", "$b", "$c"]
        .into_iter()
        .map(|name| Field::new(name, DataType::UInt32, false))
        .collect();
    let schema = Schema::new(fields);
    let scan = table_scan(Some("test"), &schema, None)?.build()?;

    let aggregated = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?;
    let plan = aggregated.filter(col("$a").gt(lit(10i64)))?.build()?;

    // The predicate on the grouping key `$a` reaches the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
      TableScan: test, full_filters=[test.$a > Int64(10)]
    "
    )
}
1595
/// When the group-by key is the expression `b + a`, a filter on plain `b`
/// is not pushed below the aggregate.
#[test]
fn filter_complex_group_by() -> Result<()> {
    let scan = test_table_scan()?;
    let group_keys = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let aggregated = LogicalPlanBuilder::from(scan).aggregate(group_keys, aggregates)?;
    let plan = aggregated.filter(col("b").gt(lit(10i64)))?.build()?;

    // The filter stays above the aggregate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b > Int64(10)
      Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
        TableScan: test
    "
    )
}
1612
/// A filter that refers to the group-by expression by its output column name
/// (`test.b + test.a`) is pushed down with the column replaced by the
/// underlying expression.
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
    let group_keys = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let aggregated =
        LogicalPlanBuilder::from(test_table_scan()?).aggregate(group_keys, aggregates)?;
    let predicate = col("test.b + test.a").gt(lit(10i64));
    let plan = aggregated.filter(predicate)?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
      TableScan: test, full_filters=[test.b + test.a > Int64(10)]
    "
    )
}
1627
/// A filter on an aggregate output (`sum(b) AS b`) is not pushed below the
/// aggregate.
#[test]
fn filter_keep_agg() -> Result<()> {
    let scan = test_table_scan()?;
    let aggregated = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?;
    let plan = aggregated.filter(col("b").gt(lit(10i64)))?.build()?;

    // `b` here is the aggregate alias, so the filter stays on top.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: b > Int64(10)
      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
        TableScan: test
    "
    )
}
1645
/// A filter on a column that is one of the window's partition keys is pushed
/// below the window.
#[test]
fn filter_move_window() -> Result<()> {
    let scan = test_table_scan()?;

    let rank_fn =
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf());
    let window = Expr::from(WindowFunction::new(rank_fn, vec![]))
        .partition_by(vec![col("a"), col("b")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let windowed = LogicalPlanBuilder::from(scan).window(vec![window])?;
    let plan = windowed.filter(col("b").gt(lit(10i64)))?.build()?;

    // `b` is a partition key, so the predicate reaches the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.b > Int64(10)]
    "
    )
}
1675
/// Same as `filter_move_window`, but with `$`-prefixed column names to check
/// that special characters in identifiers do not block the push-down.
#[test]
fn filter_window_special_identifier() -> Result<()> {
    let fields: Vec<Field> = ["$a", "$b", "$c"]
        .into_iter()
        .map(|name| Field::new(name, DataType::UInt32, false))
        .collect();
    let schema = Schema::new(fields);
    let scan = table_scan(Some("test"), &schema, None)?.build()?;

    let rank_fn =
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf());
    let window = Expr::from(WindowFunction::new(rank_fn, vec![]))
        .partition_by(vec![col("$a"), col("$b")])
        .order_by(vec![col("$c").sort(true, true)])
        .build()
        .unwrap();

    let windowed = LogicalPlanBuilder::from(scan).window(vec![window])?;
    let plan = windowed.filter(col("$b").gt(lit(10i64)))?.build()?;

    // `$b` is a partition key, so the predicate reaches the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.$b > Int64(10)]
    "
    )
}
1710
/// A conjunction whose terms each reference only partition keys is pushed
/// below the window in full.
#[test]
fn filter_move_complex_window() -> Result<()> {
    let scan = test_table_scan()?;

    let rank_fn =
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf());
    let window = Expr::from(WindowFunction::new(rank_fn, vec![]))
        .partition_by(vec![col("a"), col("b")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .window(vec![window])?
        .filter(predicate)?
        .build()?;

    // Both conjuncts arrive at the scan as separate filters.
    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
    "
    )
}
1741
/// With a window partitioned only by `a`, the conjunct on `a` is pushed down
/// while the conjunct on `b` stays above the window.
#[test]
fn filter_move_partial_window() -> Result<()> {
    let scan = test_table_scan()?;

    let rank_fn =
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf());
    let window = Expr::from(WindowFunction::new(rank_fn, vec![]))
        .partition_by(vec![col("a")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .window(vec![window])?
        .filter(predicate)?
        .build()?;

    // The conjunction is split: `a > 10` goes down, `b = 1` stays up.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b = Int64(1)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1772
/// When the window is partitioned by the expression `a + b`, a filter on that
/// same expression is not pushed below the window.
#[test]
fn filter_expression_keep_window() -> Result<()> {
    let scan = test_table_scan()?;

    let rank_fn =
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf());
    let window = Expr::from(WindowFunction::new(rank_fn, vec![]))
        .partition_by(vec![add(col("a"), col("b"))])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = add(col("a"), col("b")).gt(lit(10i64));
    let plan = LogicalPlanBuilder::from(scan)
        .window(vec![window])?
        .filter(predicate)?
        .build()?;

    // The filter stays above the window.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a + test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1806
/// A filter on the window's ORDER BY column (which is not a partition key)
/// is not pushed below the window.
#[test]
fn filter_order_keep_window() -> Result<()> {
    let scan = test_table_scan()?;

    let rank_fn =
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf());
    let window = Expr::from(WindowFunction::new(rank_fn, vec![]))
        .partition_by(vec![col("a")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let windowed = LogicalPlanBuilder::from(scan).window(vec![window])?;
    let plan = windowed.filter(col("c").gt(lit(10i64)))?.build()?;

    // `c` is only an ordering column, so the filter stays on top.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1837
/// With two window expressions partitioned by `[a]` and `[b, a]`, a filter on
/// `a` (a key shared by both) is pushed below the window node.
#[test]
fn filter_multiple_windows_common_partitions() -> Result<()> {
    let scan = test_table_scan()?;

    // Builds `rank()` over the given partition keys, ordered by `c`.
    let rank_partitioned_by = |keys: Vec<Expr>| {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(keys)
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap()
    };
    let window1 = rank_partitioned_by(vec![col("a")]);
    let window2 = rank_partitioned_by(vec![col("b"), col("a")]);

    let windowed = LogicalPlanBuilder::from(scan).window(vec![window1, window2])?;
    let plan = windowed.filter(col("a").gt(lit(10i64)))?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1879
/// With two window expressions partitioned by `[a]` and `[b, a]`, a filter on
/// `b` (a key of only one of them) is not pushed below the window node.
#[test]
fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
    let scan = test_table_scan()?;

    // Builds `rank()` over the given partition keys, ordered by `c`.
    let rank_partitioned_by = |keys: Vec<Expr>| {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(keys)
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap()
    };
    let window1 = rank_partitioned_by(vec![col("a")]);
    let window2 = rank_partitioned_by(vec![col("b"), col("a")]);

    let windowed = LogicalPlanBuilder::from(scan).window(vec![window1, window2])?;
    let plan = windowed.filter(col("b").gt(lit(10i64)))?.build()?;

    // The filter stays above the window node.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1922
/// A filter on a projection alias is rewritten in terms of the aliased column
/// when it is pushed below the projection.
#[test]
fn alias() -> Result<()> {
    let scan = test_table_scan()?;
    let projected =
        LogicalPlanBuilder::from(scan).project(vec![col("a").alias("b"), col("c")])?;
    let plan = projected.filter(col("b").eq(lit(1i64)))?.build()?;

    // `b` is resolved back to `test.a` in the pushed-down predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1940
1941 fn add(left: Expr, right: Expr) -> Expr {
1942 Expr::BinaryExpr(BinaryExpr::new(
1943 Box::new(left),
1944 Operator::Plus,
1945 Box::new(right),
1946 ))
1947 }
1948
1949 fn multiply(left: Expr, right: Expr) -> Expr {
1950 Expr::BinaryExpr(BinaryExpr::new(
1951 Box::new(left),
1952 Operator::Multiply,
1953 Box::new(right),
1954 ))
1955 }
1956
/// A filter on an alias of a compound expression is pushed below the
/// projection with the alias replaced by the full expression.
#[test]
fn complex_expression() -> Result<()> {
    let scan = test_table_scan()?;
    let computed_b = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![computed_b, col("c")])?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // Sanity-check the unoptimized shape first.
    assert_snapshot!(plan,
    @r"
    Filter: b = Int64(1)
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test
    ",
    );

    // After optimization the predicate is expressed on the scan's columns.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a * Int32(2) + test.c AS b, test.c
      TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
    "
    )
}
1986
/// A filter on an alias defined through two stacked projections is pushed to
/// the scan with both aliases expanded.
#[test]
fn complex_plan() -> Result<()> {
    let scan = test_table_scan()?;
    let inner_b = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let outer_a = multiply(col("b"), lit(3)).alias("a");
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![inner_b, col("c")])?
        .project(vec![outer_a, col("c")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // Sanity-check the unoptimized shape first.
    assert_snapshot!(plan,
    @r"
    Filter: a = Int64(1)
      Projection: b * Int32(3) AS a, test.c
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test
    ",
    );

    // `a` expands to `b * 3`, and `b` in turn to `test.a * 2 + test.c`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b * Int32(3) AS a, test.c
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
    "
    )
}
2020
/// Minimal user-defined logical node used by the tests in this module to
/// exercise filter push-down through `LogicalPlan::Extension`.
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoopPlan {
    // Child plans; the tests construct this node with one or two inputs.
    input: Vec<LogicalPlan>,
    // Schema reported by the node; the tests reuse the input scan's schema.
    schema: DFSchemaRef,
}
2026
2027 impl PartialOrd for NoopPlan {
2029 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2030 self.input
2031 .partial_cmp(&other.input)
2032 .filter(|cmp| *cmp != Ordering::Equal || self == other)
2034 }
2035 }
2036
2037 impl UserDefinedLogicalNodeCore for NoopPlan {
2038 fn name(&self) -> &str {
2039 "NoopPlan"
2040 }
2041
2042 fn inputs(&self) -> Vec<&LogicalPlan> {
2043 self.input.iter().collect()
2044 }
2045
2046 fn schema(&self) -> &DFSchemaRef {
2047 &self.schema
2048 }
2049
2050 fn expressions(&self) -> Vec<Expr> {
2051 self.input
2052 .iter()
2053 .flat_map(|child| child.expressions())
2054 .collect()
2055 }
2056
2057 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2058 HashSet::from_iter(vec!["c".to_string()])
2059 }
2060
2061 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2062 write!(f, "NoopPlan")
2063 }
2064
2065 fn with_exprs_and_inputs(
2066 &self,
2067 _exprs: Vec<Expr>,
2068 inputs: Vec<LogicalPlan>,
2069 ) -> Result<Self> {
2070 Ok(Self {
2071 input: inputs,
2072 schema: Arc::clone(&self.schema),
2073 })
2074 }
2075
2076 fn supports_limit_pushdown(&self) -> bool {
2077 false }
2079 }
2080
/// Filter push-down through a user-defined node (`NoopPlan`), covering one
/// and two inputs, with and without a predicate on the protected column `c`.
#[test]
fn user_defined_plan() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Wraps the given inputs in a `NoopPlan` extension node that reports the
    // scan's schema.
    let noop_over = |inputs: Vec<LogicalPlan>| {
        LogicalPlan::Extension(Extension {
            node: Arc::new(NoopPlan {
                input: inputs,
                schema: Arc::clone(table_scan.schema()),
            }),
        })
    };

    // Single input, predicate on `a`: pushed through the node.
    let plan = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(col("a").eq(lit(1i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Single input, predicate on `a` and `c`: only the `a` part is pushed.
    let plan = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate on `a`: pushed into both children.
    let plan = LogicalPlanBuilder::from(noop_over(vec![
        table_scan.clone(),
        table_scan.clone(),
    ]))
    .filter(col("a").eq(lit(1i64)))?
    .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate on `a` and `c`: `a` goes to both children, `c`
    // stays above the node.
    let plan = LogicalPlanBuilder::from(noop_over(vec![
        table_scan.clone(),
        table_scan.clone(),
    ]))
    .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
    .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2165
/// Two stacked filters above an aggregate: the one on the group key is pushed
/// down (through the alias), the one on the aggregate output stays.
#[test]
fn multi_filter() -> Result<()> {
    let scan = test_table_scan()?;
    let aggregated = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?;
    let on_group_key = col("b").gt(lit(10i64));
    let on_aggregate = col("sum(test.c)").gt(lit(10i64));
    let plan = aggregated
        .filter(on_group_key)?
        .filter(on_aggregate)?
        .build()?;

    // Unoptimized: both filters sit above the aggregate.
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10)
      Filter: b > Int64(10)
        Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
          Projection: test.a AS b, test.c
            TableScan: test
    ",
    );

    // Optimized: `b > 10` becomes `test.a > 10` on the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: sum(test.c) > Int64(10)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2200
/// A single conjunction mixing group-key and aggregate-output terms is split:
/// the group-key term is pushed down, the aggregate terms stay together.
#[test]
fn split_filter() -> Result<()> {
    let scan = test_table_scan()?;
    let aggregated = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?;
    let predicate = and(
        col("sum(test.c)").gt(lit(10i64)),
        and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
    );
    let plan = aggregated.filter(predicate)?.build()?;

    // Unoptimized: one filter holding all three conjuncts.
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );

    // Optimized: only the `b` conjunct moves, rewritten as `test.a > 10`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2236
/// A filter is pushed through the top projection but stops above two stacked
/// limits.
#[test]
fn double_limit() -> Result<()> {
    let scan = test_table_scan()?;
    let limited = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(20))?
        .limit(0, Some(10))?;
    let plan = limited
        .project(vec![col("a"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // The filter ends up between the outer projection and the limits.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: test.a = Int64(1)
        Limit: skip=0, fetch=10
          Limit: skip=0, fetch=20
            Projection: test.a, test.b
              TableScan: test
    "
    )
}
2261
/// A filter above a union is pushed into both branches, with the column
/// re-qualified for each input.
#[test]
fn union_all() -> Result<()> {
    let left_scan = test_table_scan()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan).build()?;
    let plan = LogicalPlanBuilder::from(left_scan)
        .union(right)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Union
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test2, full_filters=[test2.a = Int64(1)]
    "
    )
}
2280
/// A filter above a union of two aliased projections is pushed through the
/// alias and projection in each branch.
#[test]
fn union_all_on_projection() -> Result<()> {
    let scan = test_table_scan()?;
    let aliased = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b")])?
        .alias("test2")?;

    // The same aliased projection is used for both sides of the union.
    let left = aliased.clone();
    let right = aliased.build()?;
    let plan = left
        .union(right)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Union
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2308
/// Despite its name, this test builds a cross join of two projections whose
/// column names differ, and checks that each conjunct of the filter is pushed
/// to the side it references.
#[test]
fn test_union_different_schema() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let right_fields: Vec<Field> = ["d", "e", "f"]
        .into_iter()
        .map(|name| Field::new(name, DataType::UInt32, false))
        .collect();
    let schema = Schema::new(right_fields);
    let right = table_scan(Some("test1"), &schema, None)?
        .project(vec![col("d"), col("e"), col("f")])?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.d")])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test1.d
      Cross Join:
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.d, test1.e, test1.f
          TableScan: test1, full_filters=[test1.d > Int32(2)]
    "
    )
}
2342
/// Cross join of two projections that expose the same column names under
/// different qualifiers: each conjunct is pushed to the matching side.
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_scan = test_table_scan_with_name("test1")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.a")])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test1.a
      Cross Join:
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.a, test1.b, test1.c
          TableScan: test1, full_filters=[test1.a > Int32(2)]
    "
    )
}
2372
/// Two filters separated by a limit: the lower one reaches the scan, the
/// upper one is pushed through its projection but stops above the limit.
#[test]
fn filter_2_breaks_limits() -> Result<()> {
    let scan = test_table_scan()?;
    let below_limit = LogicalPlanBuilder::from(scan)
        .project(vec![col("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?;
    let plan = below_limit
        .limit(0, Some(1))?
        .project(vec![col("a")])?
        .filter(col("a").gt_eq(lit(1i64)))?
        .build()?;

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    Filter: test.a >= Int64(1)
      Projection: test.a
        Limit: skip=0, fetch=1
          Filter: test.a <= Int64(1)
            Projection: test.a
              TableScan: test
    ",
    );

    // Optimized: the two predicates are not merged across the limit.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Limit: skip=0, fetch=1
          Projection: test.a
            TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2408
/// Two adjacent filters above a limit are merged into a single conjunction
/// that stays above the limit.
#[test]
fn two_filters_on_same_depth() -> Result<()> {
    let scan = test_table_scan()?;
    let limited = LogicalPlanBuilder::from(scan).limit(0, Some(1))?;
    let plan = limited
        .filter(col("a").lt_eq(lit(1i64)))?
        .filter(col("a").gt_eq(lit(1i64)))?
        .project(vec![col("a")])?
        .build()?;

    // Unoptimized: two separate filter nodes.
    assert_snapshot!(plan,
    @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Filter: test.a <= Int64(1)
          Limit: skip=0, fetch=1
            TableScan: test
    ",
    );

    // Optimized: one filter node combining both predicates with AND.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1) AND test.a <= Int64(1)
        Limit: skip=0, fetch=1
          TableScan: test
    "
    )
}
2440
/// A filter located beneath a user-defined node is still pushed into the
/// scan; the user-defined node itself is left in place.
#[test]
fn filters_user_defined_node() -> Result<()> {
    let scan = test_table_scan()?;
    let filtered = LogicalPlanBuilder::from(scan)
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Wrap the filtered scan in the test helper's user-defined node.
    let plan = user_defined::new(filtered);

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    TestUserDefined
      Filter: test.a <= Int64(1)
        TableScan: test
    ",
    );

    // Optimized: the filter node is gone and the scan carries the predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    TestUserDefined
      TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2468
/// Inner join on `a`: a post-join filter on the left key is pushed to the
/// left input and also duplicated onto the right input's key.
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan).build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    // Optimized: both scans carry the predicate on their own `a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2509
/// Inner `USING (a)` join: a post-join filter on the unqualified key is
/// pushed to both inputs.
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan).build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Inner, using_columns)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Unoptimized shape: the unqualified `a` resolves to `test.a`.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    // Optimized: both scans carry the predicate on their own `a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2549
/// Inner join followed by a filter comparing a left column with a right
/// column: the predicate cannot go to either input and becomes the join
/// filter instead.
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let cross_side_predicate = col("c").lt_eq(col("b"));
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(cross_side_predicate)?
        .build()?;

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    Filter: test.c <= test2.b
      Inner Join: test.a = test2.a
        Projection: test.a, test.c
          TableScan: test
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );

    // Optimized: the predicate is attached to the join itself.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.c <= test2.b
      Projection: test.a, test.c
        TableScan: test
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
2594
/// Inner join followed by a filter on a non-key column of the left input:
/// the predicate is pushed to the left side only.
#[test]
fn filter_join_on_one_side() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a"), col("c")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    Filter: test.b <= Int64(1)
      Inner Join: test.a = test2.a
        Projection: test.a, test.b
          TableScan: test
        Projection: test2.a, test2.c
          TableScan: test2
    ",
    );

    // Optimized: only the left scan gets a filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b
        TableScan: test, full_filters=[test.b <= Int64(1)]
      Projection: test2.a, test2.c
        TableScan: test2
    "
    )
}
2639
/// Left `USING (a)` join with a post-join filter on the right side's key:
/// the filter itself stays above the join, and the expected plan additionally
/// shows the same bound applied to the left scan's key.
#[test]
fn filter_using_left_join() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan).build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Left, using_columns)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    // Optimized: the right scan is untouched and the filter node remains.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test, full_filters=[test.a <= Int64(1)]
        Projection: test2.a
          TableScan: test2
    "
    )
}
2681
/// Right `USING (a)` join with a post-join filter on the left side's key:
/// the filter itself stays above the join, and the expected plan additionally
/// shows the same bound applied to the right scan's key.
#[test]
fn filter_using_right_join() -> Result<()> {
    let left_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(left_scan).build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_scan)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Right, using_columns)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Unoptimized shape.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    // Optimized: the left scan is untouched and the filter node remains.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2722
/// A filter on the unqualified join key `a` above a `LEFT JOIN ... USING (a)`
/// resolves to `test.a` (see the first snapshot). The expected plan drops the
/// `Filter` node and pushes the predicate into the left scan only; it is not
/// duplicated onto the right input.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan).build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(
            right,
            JoinType::Left,
            vec![Column::from_name("a".to_string())],
        )?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Expected optimized plan: predicate lives on the left scan only.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2
    "
    )
}
2763
/// A filter on the right-side key (`test2.a`) above a
/// `RIGHT JOIN ... USING (a)`: the expected plan drops the `Filter` node and
/// pushes the predicate into the right scan only; the left scan is unchanged.
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan).build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(
            right,
            JoinType::Right,
            vec![Column::from_name("a".to_string())],
        )?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Expected optimized plan: predicate lives on the right scan only.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: Using test.a = test2.a
      TableScan: test
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2804
/// Inner join whose ON filter has three conjuncts: one referencing only the
/// left input, one referencing both inputs, and one referencing only the right
/// input. The expected plan pushes each single-sided conjunct into its own
/// scan and keeps only the two-sided `test.b < test2.b` in the join filter.
#[test]
fn join_on_with_filter() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    // left-only AND two-sided AND right-only
    let filter = col("test.c")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            (vec![Column::from_name("a")], vec![Column::from_name("a")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected optimized plan: single-sided conjuncts moved to the scans.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.c > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2850
/// Inner join whose ON filter consists solely of single-sided conjuncts. The
/// expected plan pushes both of them into the scans, leaving the join with no
/// `Filter:` clause at all.
#[test]
fn join_filter_removed() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    // left-only AND right-only; nothing needs both inputs
    let filter = col("test.b")
        .gt(lit(1u32))
        .and(col("test2.c").gt(lit(4u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            (vec![Column::from_name("a")], vec![Column::from_name("a")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected optimized plan: the join filter is gone entirely.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2895
/// Inner join on `test.a = test2.b` with an ON filter that constrains only
/// the left join key. The expected plan pushes `test.a > 1` to the left scan
/// and the equivalent `test2.b > 1` to the right scan, and removes the join
/// filter.
#[test]
fn join_filter_on_common() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("b")])?
        .build()?;
    let filter = col("test.a").gt(lit(1u32));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            (vec![Column::from_name("a")], vec![Column::from_name("b")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
      Projection: test.a
        TableScan: test
      Projection: test2.b
        TableScan: test2
    ",
    );
    // Expected optimized plan: the key predicate appears on both scans.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.b
      Projection: test.a
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
2938
/// Left join with a three-conjunct ON filter (left-only, two-sided,
/// right-only). The expected plan pushes only the right-only conjunct
/// (`test2.c > 4`) into the right scan; the left-only and two-sided conjuncts
/// remain in the join filter. This agrees with
/// `on_lr_is_preserved(JoinType::Left) == (false, true)`.
#[test]
fn left_join_on_with_filter() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    // left-only AND two-sided AND right-only
    let filter = col("test.a")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Left,
            (vec![Column::from_name("a")], vec![Column::from_name("a")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected optimized plan: only the right-only conjunct moves.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2984
/// Right join with a three-conjunct ON filter (left-only, two-sided,
/// right-only). The expected plan pushes only the left-only conjunct
/// (`test.a > 1`) into the left scan; the two-sided and right-only conjuncts
/// remain in the join filter. This agrees with
/// `on_lr_is_preserved(JoinType::Right) == (true, false)`.
#[test]
fn right_join_on_with_filter() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    // left-only AND two-sided AND right-only
    let filter = col("test.a")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Right,
            (vec![Column::from_name("a")], vec![Column::from_name("a")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected optimized plan: only the left-only conjunct moves.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3030
/// Full join with a three-conjunct ON filter. The expected optimized plan is
/// identical to the input plan: nothing is pushed to either side, which agrees
/// with `on_lr_is_preserved(JoinType::Full) == (false, false)`.
#[test]
fn full_join_on_with_filter() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    // left-only AND two-sided AND right-only
    let filter = col("test.a")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Full,
            (vec![Column::from_name("a")], vec![Column::from_name("a")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected optimized plan: unchanged.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3076
/// Test-only table source whose answer to filter-pushdown queries is fixed at
/// construction time.
struct PushDownProvider {
    /// Support level reported for every filter offered to this source.
    pub filter_support: TableProviderFilterPushDown,
}
3080
3081 #[async_trait]
3082 impl TableSource for PushDownProvider {
3083 fn schema(&self) -> SchemaRef {
3084 Arc::new(Schema::new(vec![
3085 Field::new("a", DataType::Int32, true),
3086 Field::new("b", DataType::Int32, true),
3087 ]))
3088 }
3089
3090 fn table_type(&self) -> TableType {
3091 TableType::Base
3092 }
3093
3094 fn supports_filters_pushdown(
3095 &self,
3096 filters: &[&Expr],
3097 ) -> Result<Vec<TableProviderFilterPushDown>> {
3098 Ok((0..filters.len())
3099 .map(|_| self.filter_support.clone())
3100 .collect())
3101 }
3102
3103 fn as_any(&self) -> &dyn Any {
3104 self
3105 }
3106 }
3107
3108 fn table_scan_with_pushdown_provider_builder(
3109 filter_support: TableProviderFilterPushDown,
3110 filters: Vec<Expr>,
3111 projection: Option<Vec<usize>>,
3112 ) -> Result<LogicalPlanBuilder> {
3113 let test_provider = PushDownProvider { filter_support };
3114
3115 let table_scan = LogicalPlan::TableScan(TableScan {
3116 table_name: "test".into(),
3117 filters,
3118 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3119 projection,
3120 source: Arc::new(test_provider),
3121 fetch: None,
3122 });
3123
3124 Ok(LogicalPlanBuilder::from(table_scan))
3125 }
3126
3127 fn table_scan_with_pushdown_provider(
3128 filter_support: TableProviderFilterPushDown,
3129 ) -> Result<LogicalPlan> {
3130 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3131 .filter(col("a").eq(lit(1i64)))?
3132 .build()
3133 }
3134
/// With a provider reporting `Exact` support, the expected plan has no
/// `Filter` node left: the predicate appears only on the scan, under
/// `full_filters`.
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
    let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[a = Int64(1)]"
    )
}
3144
/// With a provider reporting `Inexact` support, the expected plan keeps the
/// `Filter` node and additionally records the predicate on the scan under
/// `partial_filters`.
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
    let plan =
        table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3158
/// Checks that running the rule on an already-optimized plan does not change
/// it further: the plan is first rewritten explicitly with
/// `PushDownFilter::rewrite`, and that result is then handed to
/// `assert_optimized_plan_equal!` (which presumably applies the rule again;
/// the macro is defined outside this section). The expected snapshot shows
/// the predicate exactly once in `partial_filters`.
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let plan =
        table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

    // First, explicit application of the rule.
    let optimized_plan = PushDownFilter::new()
        .rewrite(plan, &OptimizerContext::new())
        .expect("failed to optimize plan")
        .data;

    // The already-optimized plan must yield the same single-pass result.
    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3179
/// With a provider reporting `Unsupported`, the expected plan is unchanged:
/// the `Filter` node stays and the scan carries no filters.
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
    let plan =
        table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test
    "
    )
}
3193
/// The scan already carries `a = 10` and `b > 11` as filters (with projection
/// index 0 only) and the same two predicates are applied again in a `Filter`
/// above it, with an `Inexact` provider. The expected plan keeps the `Filter`
/// node and lists each predicate once in the scan's `partial_filters`.
#[test]
fn multi_combined_filter() -> Result<()> {
    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Inexact,
        vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
        Some(vec![0]),
    )?
    .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
    .project(vec![col("a"), col("b")])?
    .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: a = Int64(10) AND b > Int64(11)
        TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3214
/// Same shape as `multi_combined_filter`, but the scan starts with no filters
/// and the provider reports `Exact`. The expected plan removes the `Filter`
/// node and lists both conjuncts in the scan's `full_filters`, including the
/// one on `b` even though the scan's projection shows only `a`.
#[test]
fn multi_combined_filter_exact() -> Result<()> {
    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Exact,
        vec![],
        Some(vec![0]),
    )?
    .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
    .project(vec![col("a"), col("b")])?
    .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3234
/// A filter that refers to a projection alias (`b`, defined as `test.a AS b`)
/// must be rewritten in terms of the underlying column before it can move
/// below the projection. The expected plan shows `test.a > 10` (not `b > 10`)
/// on the scan.
#[test]
fn test_filter_with_alias() -> Result<()> {
    let table_scan = test_table_scan()?;
    // The scan column is `test.a`; the projection renames it to `b`, and the
    // filter is written against `b`.
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // Expected optimized plan: the alias is resolved back to `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3263
/// Like `test_filter_with_alias`, but with a second pass-through projection
/// between the aliasing projection and the filter. The expected plan pushes
/// the predicate through both projections and resolves `b` to `test.a` on the
/// scan.
#[test]
fn test_filter_with_alias_2() -> Result<()> {
    let table_scan = test_table_scan()?;
    // The scan column is `test.a`; it is renamed to `b`, re-projected as `b`,
    // and only then filtered.
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // Expected optimized plan: predicate reaches the scan as `test.a > 10`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3295
/// Both filter columns are projection aliases (`test.a AS b`, `test.c AS d`).
/// The expected plan resolves each alias independently, so the scan receives
/// `test.a > 10` and `test.c > 10`.
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("b"), col("c").alias("d")])?
        .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND d > Int64(10)
      Projection: test.a AS b, test.c AS d
        TableScan: test
    ",
    );
    // Expected optimized plan: both aliases resolved on the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c AS d
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3321
/// Inner join on aliased keys (`test.a AS c` = `test2.b AS d`) with an ON
/// filter `c > 1`. The expected plan removes the join filter and pushes the
/// predicate to both inputs, resolved through each side's alias:
/// `test.a > 1` on the left scan and `test2.b > 1` on the right scan.
#[test]
fn join_filter_with_alias() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("b").alias("d")])?
        .build()?;
    let filter = col("c").gt(lit(1u32));
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            (vec![Column::from_name("c")], vec![Column::from_name("d")]),
            Some(filter),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Inner Join: c = d Filter: c > UInt32(1)
      Projection: test.a AS c
        TableScan: test
      Projection: test2.b AS d
        TableScan: test2
    ",
    );
    // Expected optimized plan: key predicate on both scans, aliases resolved.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: c = d
      Projection: test.a AS c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b AS d
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
3364
/// An `IN (...)` list filter written against a projection alias
/// (`test.a AS b`). The expected plan pushes it to the scan with the alias
/// replaced by `test.a`.
#[test]
fn test_in_filter_with_alias() -> Result<()> {
    let table_scan = test_table_scan()?;
    // The scan column is `test.a`; the projection renames it to `b`, and the
    // IN-list filter is written against `b`.
    let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_list(col("b"), filter_value, false))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // Expected optimized plan: IN list on the scan, against `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3394
/// Like `test_in_filter_with_alias`, but with an extra pass-through
/// projection between the aliasing projection and the filter. The expected
/// plan pushes the `IN (...)` predicate through both projections and rewrites
/// it against `test.a`.
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
    let table_scan = test_table_scan()?;
    // The scan column is `test.a`; it is renamed to `b`, re-projected as `b`,
    // and only then filtered with an IN list.
    let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(in_list(col("b"), filter_value, false))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // Expected optimized plan: IN list on the scan, against `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3427
/// An `IN (<subquery>)` filter written against a projection alias
/// (`test.a AS b`). The expected plan pushes the predicate to the scan with
/// the alias replaced by `test.a`; the subquery plan is then displayed beneath
/// the scan instead of beneath the `Filter`.
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
    let table_scan = test_table_scan()?;
    // Subquery side: `SELECT c FROM sq`.
    let table_scan_sq = test_table_scan_with_name("sq")?;
    let subplan = Arc::new(
        LogicalPlanBuilder::from(table_scan_sq)
            .project(vec![col("c")])?
            .build()?,
    );
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_subquery(col("b"), subplan))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: b IN (<subquery>)
      Subquery:
        Projection: sq.c
          TableScan: sq
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // Expected optimized plan: predicate (and its subquery) attached to the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN (<subquery>)]
        Subquery:
          Projection: sq.c
            TableScan: sq
    "
    )
}
3467
/// A filter above two nested `SubqueryAlias`/`Projection` layers over a
/// one-row `EmptyRelation`. The expected plan pushes the predicate through
/// every layer down to just above the `EmptyRelation`, with `b.a` replaced by
/// the literal it is defined as, giving `Int64(0) = Int64(1)`.
#[test]
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    // Roughly: SELECT b.a FROM (SELECT b.a FROM (SELECT 0 AS a) b) b WHERE b.a = 1
    let plan = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?
        .project(vec![col("b.a")])?
        .alias("b")?
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Projection: b.a
      Filter: b.a = Int64(1)
        SubqueryAlias: b
          Projection: b.a
            SubqueryAlias: b
              Projection: Int64(0) AS a
                EmptyRelation: rows=1
    ",
    );
    // Expected optimized plan: the filter sits directly above the relation,
    // expressed on the literal the column was defined as.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b.a
      SubqueryAlias: b
        Projection: b.a
          SubqueryAlias: b
            Projection: Int64(0) AS a
              Filter: Int64(0) = Int64(1)
                EmptyRelation: rows=1
    "
    )
}
3506
/// Cross join followed by a filter of the form
/// `(a = d AND b > 1) OR (b = e AND c < 10)`. The expected plan turns the
/// cross join into an inner join carrying the whole predicate as its filter,
/// and additionally pushes the left-only disjunction
/// `test.b > 1 OR test.c < 10` into the left scan. The same expectation is
/// checked twice: once via `assert_optimized_plan_eq_with_rewrite_predicate!`
/// and once on a plan that was already rewritten by `PushDownFilter`.
///
/// NOTE(review): the join has an empty ON clause; confirm the exact spacing
/// between `Inner Join:` and `Filter:` in both snapshots against the real plan
/// display, since whitespace in this listing may have been collapsed.
#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test1")?;
    // Both right-side output columns are aliases of `test1.a`.
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;
    let filter = or(
        and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
        and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
    );
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .filter(filter)?
        .build()?;

    assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
    Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    ")?;

    // Rewrite explicitly once, then check that the already-optimized plan
    // still matches the same expectation.
    let optimized_plan = PushDownFilter::new()
        .rewrite(plan, &OptimizerContext::new())
        .expect("failed to optimize plan")
        .data;
    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
    Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    "
    )
}
3552
/// A filter on the right-side key (`test2.a`) above a `LeftSemi` join on
/// `test1.a = test2.a`. The expected plan keeps the `Filter` node, adds the
/// equivalent `test1.a <= 1` predicate to the left scan, and leaves the right
/// input unchanged.
#[test]
fn left_semi_join() -> Result<()> {
    let left = test_table_scan_with_name("test1")?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::LeftSemi,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            None,
        )?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected optimized plan: filter stays; only the left scan gains a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1, full_filters=[test1.a <= Int64(1)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3595
/// `LeftSemi` join with an ON filter made of one left-only and one right-only
/// conjunct. The expected plan pushes each conjunct into its own scan and
/// removes the join filter, in agreement with
/// `on_lr_is_preserved(JoinType::LeftSemi) == (true, true)`.
#[test]
fn left_semi_join_with_filters() -> Result<()> {
    let left = test_table_scan_with_name("test1")?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::LeftSemi,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            Some(
                col("test1.b")
                    .gt(lit(1u32))
                    .and(col("test2.b").gt(lit(2u32))),
            ),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected optimized plan: both conjuncts pushed to their scans.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3639
/// A filter on the left-side key (`test1.a`) above a `RightSemi` join on
/// `test1.a = test2.a`. The expected plan keeps the `Filter` node, adds the
/// equivalent `test2.a <= 1` predicate to the right scan, and leaves the left
/// input unchanged.
#[test]
fn right_semi_join() -> Result<()> {
    let left = test_table_scan_with_name("test1")?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::RightSemi,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            None,
        )?
        .filter(col("test1.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected optimized plan: filter stays; only the right scan gains a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
3682
/// `RightSemi` join with an ON filter made of one left-only and one
/// right-only conjunct. The expected plan pushes each conjunct into its own
/// scan and removes the join filter, in agreement with
/// `on_lr_is_preserved(JoinType::RightSemi) == (true, true)`.
#[test]
fn right_semi_join_with_filters() -> Result<()> {
    let left = test_table_scan_with_name("test1")?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::RightSemi,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            Some(
                col("test1.b")
                    .gt(lit(1u32))
                    .and(col("test2.b").gt(lit(2u32))),
            ),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected optimized plan: both conjuncts pushed to their scans.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3726
/// A filter on the right-side key (`test2.a`) above a `LeftAnti` join on
/// `test1.a = test2.a`. The expected plan keeps the `Filter` node, adds the
/// equivalent `test1.a > 2` predicate to the left scan, and leaves the right
/// input unchanged.
#[test]
fn left_anti_join() -> Result<()> {
    let table_scan = test_table_scan_with_name("test1")?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::LeftAnti,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            None,
        )?
        .filter(col("test2.a").gt(lit(2u32)))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected optimized plan: filter stays; only the left scan gains a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1, full_filters=[test1.a > UInt32(2)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3774
/// `LeftAnti` join with an ON filter made of one left-only and one right-only
/// conjunct. The expected plan pushes only the right-only conjunct
/// (`test2.b > 2`) into the right scan; the left-only conjunct stays in the
/// join filter. This agrees with
/// `on_lr_is_preserved(JoinType::LeftAnti) == (false, true)`.
#[test]
fn left_anti_join_with_filters() -> Result<()> {
    let table_scan = test_table_scan_with_name("test1")?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::LeftAnti,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            Some(
                col("test1.b")
                    .gt(lit(1u32))
                    .and(col("test2.b").gt(lit(2u32))),
            ),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected optimized plan: only the right-only conjunct moves.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3823
/// A filter on the left-side key (`test1.a`) above a `RightAnti` join on
/// `test1.a = test2.a`. The expected plan keeps the `Filter` node, adds the
/// equivalent `test2.a > 2` predicate to the right scan, and leaves the left
/// input unchanged.
#[test]
fn right_anti_join() -> Result<()> {
    let table_scan = test_table_scan_with_name("test1")?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::RightAnti,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            None,
        )?
        .filter(col("test1.a").gt(lit(2u32)))?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected optimized plan: filter stays; only the right scan gains a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a > UInt32(2)]
    "
    )
}
3871
/// `RightAnti` join with an ON filter made of one left-only and one
/// right-only conjunct. The expected plan pushes only the left-only conjunct
/// (`test1.b > 1`) into the left scan; the right-only conjunct stays in the
/// join filter. This agrees with
/// `on_lr_is_preserved(JoinType::RightAnti) == (true, false)`.
#[test]
fn right_anti_join_with_filters() -> Result<()> {
    let table_scan = test_table_scan_with_name("test1")?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test2")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::RightAnti,
            (
                vec![Column::from_qualified_name("test1.a")],
                vec![Column::from_qualified_name("test2.a")],
            ),
            Some(
                col("test1.b")
                    .gt(lit(1u32))
                    .and(col("test2.b").gt(lit(2u32))),
            ),
        )?
        .build()?;

    // Plan exactly as built, before optimization (reference only).
    assert_snapshot!(plan,
    @r"
    RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected optimized plan: only the left-only conjunct moves.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
3920
/// Minimal scalar UDF used by the tests; its only state is the signature it
/// reports, which lets a test choose the function's volatility.
#[derive(Debug, PartialEq, Eq, Hash)]
struct TestScalarUDF {
    /// Signature returned by `ScalarUDFImpl::signature`.
    signature: Signature,
}
3925
3926 impl ScalarUDFImpl for TestScalarUDF {
3927 fn as_any(&self) -> &dyn Any {
3928 self
3929 }
3930 fn name(&self) -> &str {
3931 "TestScalarUDF"
3932 }
3933
3934 fn signature(&self) -> &Signature {
3935 &self.signature
3936 }
3937
3938 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3939 Ok(DataType::Int32)
3940 }
3941
3942 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3943 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3944 }
3945 }
3946
3947 #[test]
3948 fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3949 let table_scan = test_table_scan_with_name("test1")?;
3951 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3952 signature: Signature::exact(vec![], Volatility::Volatile),
3953 });
3954 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3955
3956 let plan = LogicalPlanBuilder::from(table_scan)
3957 .aggregate(vec![col("a")], vec![sum(col("b"))])?
3958 .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3959 .alias("t")?
3960 .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3961 .project(vec![col("t.a"), col("t.r")])?
3962 .build()?;
3963
3964 assert_snapshot!(plan,
3965 @r"
3966 Projection: t.a, t.r
3967 Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3968 SubqueryAlias: t
3969 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3970 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3971 TableScan: test1
3972 ",
3973 );
3974 assert_optimized_plan_equal!(
3975 plan,
3976 @r"
3977 Projection: t.a, t.r
3978 SubqueryAlias: t
3979 Filter: r > Float64(0.5)
3980 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3981 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3982 TableScan: test1, full_filters=[test1.a > Int32(5)]
3983 "
3984 )
3985 }
3986
3987 #[test]
3988 fn test_push_down_volatile_function_in_join() -> Result<()> {
3989 let table_scan = test_table_scan_with_name("test1")?;
3991 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3992 signature: Signature::exact(vec![], Volatility::Volatile),
3993 });
3994 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3995 let left = LogicalPlanBuilder::from(table_scan).build()?;
3996 let right_table_scan = test_table_scan_with_name("test2")?;
3997 let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3998 let plan = LogicalPlanBuilder::from(left)
3999 .join(
4000 right,
4001 JoinType::Inner,
4002 (
4003 vec![Column::from_qualified_name("test1.a")],
4004 vec![Column::from_qualified_name("test2.a")],
4005 ),
4006 None,
4007 )?
4008 .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4009 .alias("t")?
4010 .filter(col("t.r").gt(lit(0.8)))?
4011 .project(vec![col("t.a"), col("t.r")])?
4012 .build()?;
4013
4014 assert_snapshot!(plan,
4015 @r"
4016 Projection: t.a, t.r
4017 Filter: t.r > Float64(0.8)
4018 SubqueryAlias: t
4019 Projection: test1.a AS a, TestScalarUDF() AS r
4020 Inner Join: test1.a = test2.a
4021 TableScan: test1
4022 TableScan: test2
4023 ",
4024 );
4025 assert_optimized_plan_equal!(
4026 plan,
4027 @r"
4028 Projection: t.a, t.r
4029 SubqueryAlias: t
4030 Filter: r > Float64(0.8)
4031 Projection: test1.a AS a, TestScalarUDF() AS r
4032 Inner Join: test1.a = test2.a
4033 TableScan: test1
4034 TableScan: test2
4035 "
4036 )
4037 }
4038
4039 #[test]
4040 fn test_push_down_volatile_table_scan() -> Result<()> {
4041 let table_scan = test_table_scan()?;
4043 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4044 signature: Signature::exact(vec![], Volatility::Volatile),
4045 });
4046 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4047 let plan = LogicalPlanBuilder::from(table_scan)
4048 .project(vec![col("a"), col("b")])?
4049 .filter(expr.gt(lit(0.1)))?
4050 .build()?;
4051
4052 assert_snapshot!(plan,
4053 @r"
4054 Filter: TestScalarUDF() > Float64(0.1)
4055 Projection: test.a, test.b
4056 TableScan: test
4057 ",
4058 );
4059 assert_optimized_plan_equal!(
4060 plan,
4061 @r"
4062 Projection: test.a, test.b
4063 Filter: TestScalarUDF() > Float64(0.1)
4064 TableScan: test
4065 "
4066 )
4067 }
4068
4069 #[test]
4070 fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4071 let table_scan = test_table_scan()?;
4073 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4074 signature: Signature::exact(vec![], Volatility::Volatile),
4075 });
4076 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4077 let plan = LogicalPlanBuilder::from(table_scan)
4078 .project(vec![col("a"), col("b")])?
4079 .filter(
4080 expr.gt(lit(0.1))
4081 .and(col("t.a").gt(lit(5)))
4082 .and(col("t.b").gt(lit(10))),
4083 )?
4084 .build()?;
4085
4086 assert_snapshot!(plan,
4087 @r"
4088 Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4089 Projection: test.a, test.b
4090 TableScan: test
4091 ",
4092 );
4093 assert_optimized_plan_equal!(
4094 plan,
4095 @r"
4096 Projection: test.a, test.b
4097 Filter: TestScalarUDF() > Float64(0.1)
4098 TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4099 "
4100 )
4101 }
4102
4103 #[test]
4104 fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4105 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4107 signature: Signature::exact(vec![], Volatility::Volatile),
4108 });
4109 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4110 let plan = table_scan_with_pushdown_provider_builder(
4111 TableProviderFilterPushDown::Unsupported,
4112 vec![],
4113 None,
4114 )?
4115 .project(vec![col("a"), col("b")])?
4116 .filter(
4117 expr.gt(lit(0.1))
4118 .and(col("t.a").gt(lit(5)))
4119 .and(col("t.b").gt(lit(10))),
4120 )?
4121 .build()?;
4122
4123 assert_snapshot!(plan,
4124 @r"
4125 Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4126 Projection: a, b
4127 TableScan: test
4128 ",
4129 );
4130 assert_optimized_plan_equal!(
4131 plan,
4132 @r"
4133 Projection: a, b
4134 Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4135 TableScan: test
4136 "
4137 )
4138 }
4139
4140 #[test]
4141 fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4142 #[derive(Debug, Hash, Eq, PartialEq)]
4144 struct TestUserNode {
4145 schema: DFSchemaRef,
4146 }
4147
4148 impl PartialOrd for TestUserNode {
4149 fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4150 None
4151 }
4152 }
4153
4154 impl TestUserNode {
4155 fn new() -> Self {
4156 let schema = Arc::new(
4157 DFSchema::new_with_metadata(
4158 vec![(None, Field::new("a", DataType::Int64, false).into())],
4159 Default::default(),
4160 )
4161 .unwrap(),
4162 );
4163
4164 Self { schema }
4165 }
4166 }
4167
4168 impl UserDefinedLogicalNodeCore for TestUserNode {
4169 fn name(&self) -> &str {
4170 "test_node"
4171 }
4172
4173 fn inputs(&self) -> Vec<&LogicalPlan> {
4174 vec![]
4175 }
4176
4177 fn schema(&self) -> &DFSchemaRef {
4178 &self.schema
4179 }
4180
4181 fn expressions(&self) -> Vec<Expr> {
4182 vec![]
4183 }
4184
4185 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4186 write!(f, "TestUserNode")
4187 }
4188
4189 fn with_exprs_and_inputs(
4190 &self,
4191 exprs: Vec<Expr>,
4192 inputs: Vec<LogicalPlan>,
4193 ) -> Result<Self> {
4194 assert!(exprs.is_empty());
4195 assert!(inputs.is_empty());
4196 Ok(Self {
4197 schema: Arc::clone(&self.schema),
4198 })
4199 }
4200 }
4201
4202 let node = LogicalPlan::Extension(Extension {
4204 node: Arc::new(TestUserNode::new()),
4205 });
4206
4207 let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4208
4209 assert_snapshot!(plan,
4211 @r"
4212 Filter: Boolean(false)
4213 TestUserNode
4214 ",
4215 );
4216 assert_optimized_plan_equal!(
4218 plan,
4219 @r"
4220 Filter: Boolean(false)
4221 TestUserNode
4222 "
4223 )
4224 }
4225}