1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31 Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32 internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41 BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48use datafusion_expr::ExpressionPlacement;
49
/// Optimizer rule that pushes filter predicates down the logical plan.
///
/// The struct carries no state; all of the rewriting logic lives in its
/// [`OptimizerRule`] implementation.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
139
140pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
165 match join_type {
166 JoinType::Inner => (true, true),
167 JoinType::Left => (true, false),
168 JoinType::Right => (false, true),
169 JoinType::Full => (false, false),
170 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
173 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
176 }
177}
178
179pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
189 match join_type {
190 JoinType::Inner => (true, true),
191 JoinType::Left => (false, true),
192 JoinType::Right => (true, false),
193 JoinType::Full => (false, false),
194 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
195 JoinType::LeftAnti => (false, true),
196 JoinType::RightAnti => (true, false),
197 JoinType::LeftMark => (false, true),
198 JoinType::RightMark => (true, false),
199 }
200}
201
202#[derive(Debug)]
205struct ColumnChecker<'a> {
206 left_schema: &'a DFSchema,
208 left_columns: Option<HashSet<Column>>,
210 right_schema: &'a DFSchema,
212 right_columns: Option<HashSet<Column>>,
214}
215
216impl<'a> ColumnChecker<'a> {
217 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
218 Self {
219 left_schema,
220 left_columns: None,
221 right_schema,
222 right_columns: None,
223 }
224 }
225
226 fn is_left_only(&mut self, predicate: &Expr) -> bool {
228 if self.left_columns.is_none() {
229 self.left_columns = Some(schema_columns(self.left_schema));
230 }
231 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
232 }
233
234 fn is_right_only(&mut self, predicate: &Expr) -> bool {
236 if self.right_columns.is_none() {
237 self.right_columns = Some(schema_columns(self.right_schema));
238 }
239 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
240 }
241}
242
243fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
245 schema
246 .iter()
247 .flat_map(|(qualifier, field)| {
248 [
249 Column::new(qualifier.cloned(), field.name()),
250 Column::new_unqualified(field.name()),
252 ]
253 })
254 .collect::<HashSet<_>>()
255}
256
257fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
259 let mut is_evaluate = true;
260 predicate.apply(|expr| match expr {
261 Expr::Column(_)
262 | Expr::Literal(_, _)
263 | Expr::Placeholder(_)
264 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
265 Expr::Exists { .. }
266 | Expr::InSubquery(_)
267 | Expr::SetComparison(_)
268 | Expr::ScalarSubquery(_)
269 | Expr::OuterReferenceColumn(_, _)
270 | Expr::Unnest(_) => {
271 is_evaluate = false;
272 Ok(TreeNodeRecursion::Stop)
273 }
274 Expr::Alias(_)
275 | Expr::BinaryExpr(_)
276 | Expr::Like(_)
277 | Expr::SimilarTo(_)
278 | Expr::Not(_)
279 | Expr::IsNotNull(_)
280 | Expr::IsNull(_)
281 | Expr::IsTrue(_)
282 | Expr::IsFalse(_)
283 | Expr::IsUnknown(_)
284 | Expr::IsNotTrue(_)
285 | Expr::IsNotFalse(_)
286 | Expr::IsNotUnknown(_)
287 | Expr::Negative(_)
288 | Expr::Between(_)
289 | Expr::Case(_)
290 | Expr::Cast(_)
291 | Expr::TryCast(_)
292 | Expr::InList { .. }
293 | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
294 #[expect(deprecated)]
296 Expr::AggregateFunction(_)
297 | Expr::WindowFunction(_)
298 | Expr::Wildcard { .. }
299 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
300 })?;
301 Ok(is_evaluate)
302}
303
304fn extract_or_clauses_for_join<'a>(
338 filters: &'a [Expr],
339 schema: &'a DFSchema,
340) -> impl Iterator<Item = Expr> + 'a {
341 let schema_columns = schema_columns(schema);
342
343 filters.iter().filter_map(move |expr| {
345 if let Expr::BinaryExpr(BinaryExpr {
346 left,
347 op: Operator::Or,
348 right,
349 }) = expr
350 {
351 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
352 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
353
354 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
356 return Some(or(left_expr, right_expr));
357 }
358 }
359 None
360 })
361}
362
363fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
375 let mut predicate = None;
376
377 match expr {
378 Expr::BinaryExpr(BinaryExpr {
379 left: l_expr,
380 op: Operator::Or,
381 right: r_expr,
382 }) => {
383 let l_expr = extract_or_clause(l_expr, schema_columns);
384 let r_expr = extract_or_clause(r_expr, schema_columns);
385
386 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
387 predicate = Some(or(l_expr, r_expr));
388 }
389 }
390 Expr::BinaryExpr(BinaryExpr {
391 left: l_expr,
392 op: Operator::And,
393 right: r_expr,
394 }) => {
395 let l_expr = extract_or_clause(l_expr, schema_columns);
396 let r_expr = extract_or_clause(r_expr, schema_columns);
397
398 match (l_expr, r_expr) {
399 (Some(l_expr), Some(r_expr)) => {
400 predicate = Some(and(l_expr, r_expr));
401 }
402 (Some(l_expr), None) => {
403 predicate = Some(l_expr);
404 }
405 (None, Some(r_expr)) => {
406 predicate = Some(r_expr);
407 }
408 (None, None) => {
409 predicate = None;
410 }
411 }
412 }
413 _ => {
414 if has_all_column_refs(expr, schema_columns) {
415 predicate = Some(expr.clone());
416 }
417 }
418 }
419
420 predicate
421}
422
423fn push_down_all_join(
425 predicates: Vec<Expr>,
426 inferred_join_predicates: Vec<Expr>,
427 mut join: Join,
428 on_filter: Vec<Expr>,
429) -> Result<Transformed<LogicalPlan>> {
430 let is_inner_join = join.join_type == JoinType::Inner;
431 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
433
434 let left_schema = join.left.schema();
439 let right_schema = join.right.schema();
440 let mut left_push = vec![];
441 let mut right_push = vec![];
442 let mut keep_predicates = vec![];
443 let mut join_conditions = vec![];
444 let mut checker = ColumnChecker::new(left_schema, right_schema);
445 for predicate in predicates {
446 if left_preserved && checker.is_left_only(&predicate) {
447 left_push.push(predicate);
448 } else if right_preserved && checker.is_right_only(&predicate) {
449 right_push.push(predicate);
450 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
451 join_conditions.push(predicate);
454 } else {
455 keep_predicates.push(predicate);
456 }
457 }
458
459 for predicate in inferred_join_predicates {
461 if checker.is_left_only(&predicate) {
462 left_push.push(predicate);
463 } else if checker.is_right_only(&predicate) {
464 right_push.push(predicate);
465 }
466 }
467
468 let mut on_filter_join_conditions = vec![];
469 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
470
471 if !on_filter.is_empty() {
472 for on in on_filter {
473 if on_left_preserved && checker.is_left_only(&on) {
474 left_push.push(on)
475 } else if on_right_preserved && checker.is_right_only(&on) {
476 right_push.push(on)
477 } else {
478 on_filter_join_conditions.push(on)
479 }
480 }
481 }
482
483 if left_preserved {
486 left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
487 left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
488 }
489 if right_preserved {
490 right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
491 right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
492 }
493
494 if on_left_preserved {
497 left_push.extend(extract_or_clauses_for_join(
498 &on_filter_join_conditions,
499 left_schema,
500 ));
501 }
502 if on_right_preserved {
503 right_push.extend(extract_or_clauses_for_join(
504 &on_filter_join_conditions,
505 right_schema,
506 ));
507 }
508
509 if let Some(predicate) = conjunction(left_push) {
510 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
511 }
512 if let Some(predicate) = conjunction(right_push) {
513 join.right =
514 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
515 }
516
517 join_conditions.extend(on_filter_join_conditions);
519 join.filter = conjunction(join_conditions);
520
521 let plan = LogicalPlan::Join(join);
523 let plan = if let Some(predicate) = conjunction(keep_predicates) {
524 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
525 } else {
526 plan
527 };
528 Ok(Transformed::yes(plan))
529}
530
531fn push_down_join(
532 join: Join,
533 parent_predicate: Option<&Expr>,
534) -> Result<Transformed<LogicalPlan>> {
535 let predicates = parent_predicate
537 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
538
539 let on_filters = join
541 .filter
542 .as_ref()
543 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
544
545 let inferred_join_predicates =
547 infer_join_predicates(&join, &predicates, &on_filters)?;
548
549 if on_filters.is_empty()
550 && predicates.is_empty()
551 && inferred_join_predicates.is_empty()
552 {
553 return Ok(Transformed::no(LogicalPlan::Join(join)));
554 }
555
556 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
557}
558
559fn infer_join_predicates(
569 join: &Join,
570 predicates: &[Expr],
571 on_filters: &[Expr],
572) -> Result<Vec<Expr>> {
573 let join_col_keys = join
575 .on
576 .iter()
577 .filter_map(|(l, r)| {
578 let left_col = l.try_as_col()?;
579 let right_col = r.try_as_col()?;
580 Some((left_col, right_col))
581 })
582 .collect::<Vec<_>>();
583
584 let join_type = join.join_type;
585
586 let mut inferred_predicates = InferredPredicates::new(join_type);
587
588 infer_join_predicates_from_predicates(
589 &join_col_keys,
590 predicates,
591 &mut inferred_predicates,
592 )?;
593
594 infer_join_predicates_from_on_filters(
595 &join_col_keys,
596 join_type,
597 on_filters,
598 &mut inferred_predicates,
599 )?;
600
601 Ok(inferred_predicates.predicates)
602}
603
604struct InferredPredicates {
613 predicates: Vec<Expr>,
614 is_inner_join: bool,
615}
616
617impl InferredPredicates {
618 fn new(join_type: JoinType) -> Self {
619 Self {
620 predicates: vec![],
621 is_inner_join: join_type == JoinType::Inner,
622 }
623 }
624
625 fn try_build_predicate(
626 &mut self,
627 predicate: Expr,
628 replace_map: &HashMap<&Column, &Column>,
629 ) -> Result<()> {
630 if self.is_inner_join
631 || matches!(
632 is_restrict_null_predicate(
633 predicate.clone(),
634 replace_map.keys().cloned()
635 ),
636 Ok(true)
637 )
638 {
639 self.predicates.push(replace_col(predicate, replace_map)?);
640 }
641
642 Ok(())
643 }
644}
645
/// Infers join predicates from `predicates` and appends them to
/// `inferred_predicates`.
///
/// Delegates to [`infer_join_predicates_impl`] with both replacement
/// directions (left-to-right and right-to-left) enabled.
///
/// * `join_col_keys`: pairs of `(left, right)` join-key columns.
/// * `predicates`: the predicates to infer from.
/// * `inferred_predicates`: accumulator receiving the inferred predicates.
fn infer_join_predicates_from_predicates(
    join_col_keys: &[(&Column, &Column)],
    predicates: &[Expr],
    inferred_predicates: &mut InferredPredicates,
) -> Result<()> {
    infer_join_predicates_impl::<true, true>(
        join_col_keys,
        predicates,
        inferred_predicates,
    )
}
665
666fn infer_join_predicates_from_on_filters(
678 join_col_keys: &[(&Column, &Column)],
679 join_type: JoinType,
680 on_filters: &[Expr],
681 inferred_predicates: &mut InferredPredicates,
682) -> Result<()> {
683 match join_type {
684 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
685 JoinType::Inner => infer_join_predicates_impl::<true, true>(
686 join_col_keys,
687 on_filters,
688 inferred_predicates,
689 ),
690 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
691 infer_join_predicates_impl::<true, false>(
692 join_col_keys,
693 on_filters,
694 inferred_predicates,
695 )
696 }
697 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
698 infer_join_predicates_impl::<false, true>(
699 join_col_keys,
700 on_filters,
701 inferred_predicates,
702 )
703 }
704 }
705}
706
707fn infer_join_predicates_impl<
723 const ENABLE_LEFT_TO_RIGHT: bool,
724 const ENABLE_RIGHT_TO_LEFT: bool,
725>(
726 join_col_keys: &[(&Column, &Column)],
727 input_predicates: &[Expr],
728 inferred_predicates: &mut InferredPredicates,
729) -> Result<()> {
730 for predicate in input_predicates {
731 let mut join_cols_to_replace = HashMap::new();
732
733 for &col in &predicate.column_refs() {
734 for (l, r) in join_col_keys.iter() {
735 if ENABLE_LEFT_TO_RIGHT && col == *l {
736 join_cols_to_replace.insert(col, *r);
737 break;
738 }
739 if ENABLE_RIGHT_TO_LEFT && col == *r {
740 join_cols_to_replace.insert(col, *l);
741 break;
742 }
743 }
744 }
745 if join_cols_to_replace.is_empty() {
746 continue;
747 }
748
749 inferred_predicates
750 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
751 }
752 Ok(())
753}
754
impl OptimizerRule for PushDownFilter {
    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// The rule asks to be applied top-down over the plan.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements [`OptimizerRule::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Tries to push a filter one level (or more) down the plan.
    ///
    /// * A bare `Join` has the conjuncts of its own filter pushed into its
    ///   inputs where possible.
    /// * A `Filter` is first simplified, then handled depending on the kind of
    ///   its input (see the comments on the individual match arms).
    /// * Any other plan node is returned untransformed.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // The options are fetched but not used by this rule.
        let _ = config.options();
        // A join without a filter on top: only its own filter can be pushed.
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        // Remember the schema of the incoming plan; the `Union` arm reuses it
        // as the schema of the rebuilt union.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        // Simplify the conjuncts; the predicate is only rebuilt when the
        // number of conjuncts changed.
        let predicate = split_conjunction_owned(filter.predicate.clone());
        let old_predicate_len = predicate.len();
        let new_predicates = simplify_predicates(predicate)?;
        if old_predicate_len != new_predicates.len() {
            let Some(new_predicate) = conjunction(new_predicates) else {
                // No conjunct left: drop the filter and return its input.
                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
            };
            filter.predicate = new_predicate;
        }

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over filter: merge both predicates into a single filter
            // (duplicates removed through the `IndexSet`) and retry on the
            // merged filter.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                self.rewrite(new_filter, config)
            }
            // Repartition: the whole filter is moved below the node.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            // Distinct: the whole filter is moved below the node.
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            // Sort: the whole filter is moved below the node.
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            // SubqueryAlias: rewrite column references from the alias schema
            // to the input schema (matched by position), then move the filter
            // below the alias.
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            // Projection: `rewrite_projection` pushes what it can; conjuncts
            // it returns as `keep_predicate` stay in a filter on top.
            LogicalPlan::Projection(projection) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: restore the original filter.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            // Unnest: conjuncts that touch columns produced by the unnest stay
            // above it; the remaining conjuncts are pushed below.
            LogicalPlan::Unnest(mut unnest) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                let mut unnest_struct_columns = vec![];

                // Build the `<field>.<child>` column names for every struct
                // column that is being unnested.
                for idx in &unnest.struct_type_columns {
                    let (sub_qualifier, field) =
                        unnest.input.schema().qualified_field(*idx);
                    let field_name = field.name().clone();

                    if let DataType::Struct(children) = field.data_type() {
                        for child in children {
                            let child_name = child.name().clone();
                            unnest_struct_columns.push(Column::new(
                                sub_qualifier.cloned(),
                                format!("{field_name}.{child_name}"),
                            ));
                        }
                    }
                }

                for predicate in predicates {
                    // Collect all columns referenced by this conjunct.
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    let contains_list_columns =
                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                            accum.contains(&unnest_list.output_column)
                        });
                    let contains_struct_columns =
                        unnest_struct_columns.iter().any(|c| accum.contains(c));

                    if contains_list_columns || contains_struct_columns {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Nothing can be pushed: restore the original filter.
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Take the unnest input so a filter can be placed between the
                // unnest node and that input.
                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot be `None`: `non_unnest_predicates` was checked to
                    // be non-empty above.
                    conjunction(non_unnest_predicates).unwrap(),
                    unnest_input,
                )?);

                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                // Conjuncts that refer to unnest output stay on top.
                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            // Union: a copy of the filter, with columns renamed to each
            // input's schema (matched by position), is placed on every input.
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            // Aggregate: conjuncts whose columns are all group-by outputs are
            // pushed below the aggregate; the rest stay above.
            LogicalPlan::Aggregate(agg) => {
                // Output columns of the group-by expressions.
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| {
                        let (relation, name) = e.qualified_name();
                        Column::new(relation, name)
                    })
                    .collect::<HashSet<_>>();

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Map each group expression's schema name back to the
                // expression itself, so pushed conjuncts refer to the
                // aggregate's input rather than its output.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // Place the pushable conjuncts below the aggregate.
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Keep the remaining conjuncts above the aggregate.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Window: conjuncts whose columns are all partition keys shared by
            // every window expression are pushed below the window node.
            LogicalPlan::Window(window) => {
                // Partition-by columns of a single window function.
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| {
                            let (relation, name) = c.qualified_name();
                            Column::new(relation, name)
                        })
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        // Window expressions are expected to be a window
                        // function, optionally wrapped in an alias; anything
                        // else panics via `unreachable!`.
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    unreachable!()
                                }
                            }
                            _ => {
                                unreachable!()
                            }
                        }
                    })
                    // Intersect the key sets of all window expressions.
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // Place the pushable conjuncts below the window node.
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Keep the remaining conjuncts above the window node.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Join below a filter: push the filter's conjuncts through it.
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            // TableScan: offer the non-volatile conjuncts to the table source.
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile conjuncts are never offered to the source.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // Ask the source how it supports each non-volatile conjunct.
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                assert_eq_or_internal_err!(
                    non_volatile_filters.len(),
                    supported_filters.len(),
                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                    supported_filters.len(),
                    non_volatile_filters.len()
                );

                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // Everything that is not `Unsupported` is added to the scan.
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Existing scan filters come first; duplicates are removed.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Everything that is not `Exact`, plus the volatile
                // conjuncts, stays in a filter above the scan.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            // Extension node: conjuncts that do not reference any column the
            // node lists in `prevent_predicate_push_down_columns` are pushed
            // to every input of the node.
            LogicalPlan::Extension(extension_plan) => {
                // A node without inputs has nowhere to push to.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` = push, `false` = keep.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            // Keep above the extension node.
                            Ok(false)
                        } else {
                            // Push below the extension node.
                            Ok(true)
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Every conjunct is kept: restore the original filter.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Split the conjuncts according to the flags computed above.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                // Wrap every input of the node in a filter with the pushed
                // conjuncts.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the extension node on top of the new children.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                // Conjuncts that could not be pushed stay on top.
                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input: the filter is left where it is.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1267
1268fn rewrite_projection(
1296 predicates: Vec<Expr>,
1297 mut projection: Projection,
1298) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1299 let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
1306 .schema
1307 .iter()
1308 .zip(projection.expr.iter())
1309 .map(|((qualifier, field), expr)| {
1310 let expr = expr.clone().unalias();
1312
1313 (qualified_name(qualifier, field.name()), expr)
1314 })
1315 .partition(|(_, value)| {
1316 value.is_volatile()
1317 || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1318 });
1319
1320 let mut push_predicates = vec![];
1321 let mut keep_predicates = vec![];
1322 for expr in predicates {
1323 if contain(&expr, &non_pushable_map) {
1324 keep_predicates.push(expr);
1325 } else {
1326 push_predicates.push(expr);
1327 }
1328 }
1329
1330 match conjunction(push_predicates) {
1331 Some(expr) => {
1332 let new_filter = LogicalPlan::Filter(Filter::try_new(
1335 replace_cols_by_name(expr, &pushable_map)?,
1336 std::mem::take(&mut projection.input),
1337 )?);
1338
1339 projection.input = Arc::new(new_filter);
1340
1341 Ok((
1342 Transformed::yes(LogicalPlan::Projection(projection)),
1343 conjunction(keep_predicates),
1344 ))
1345 }
1346 None => Ok((
1347 Transformed::no(LogicalPlan::Projection(projection)),
1348 conjunction(keep_predicates),
1349 )),
1350 }
1351}
1352
1353pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1355 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1356}
1357
1358fn insert_below(
1372 plan: LogicalPlan,
1373 new_child: LogicalPlan,
1374) -> Result<Transformed<LogicalPlan>> {
1375 let mut new_child = Some(new_child);
1376 let transformed_plan = plan.map_children(|_child| {
1377 if let Some(new_child) = new_child.take() {
1378 Ok(Transformed::yes(new_child))
1379 } else {
1380 internal_err!("node had more than one input")
1382 }
1383 })?;
1384
1385 assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1387
1388 Ok(transformed_plan)
1389}
1390
impl PushDownFilter {
    // Constructs the rule. The struct has no fields, so there is nothing to
    // configure.
    //
    // NOTE: this is deliberately a plain comment. The `#[expect(missing_docs)]`
    // attribute below expects the item to have no `///` documentation.
    #[expect(missing_docs)]
    pub fn new() -> Self {
        Self {}
    }
}
1397
1398pub fn replace_cols_by_name(
1400 e: Expr,
1401 replace_map: &HashMap<String, Expr>,
1402) -> Result<Expr> {
1403 e.transform_up(|expr| {
1404 Ok(if let Expr::Column(c) = &expr {
1405 match replace_map.get(&c.flat_name()) {
1406 Some(new_c) => Transformed::yes(new_c.clone()),
1407 None => Transformed::no(expr),
1408 }
1409 } else {
1410 Transformed::no(expr)
1411 })
1412 })
1413 .data()
1414}
1415
1416fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1418 let mut is_contain = false;
1419 e.apply(|expr| {
1420 Ok(if let Expr::Column(c) = &expr {
1421 match check_map.get(&c.flat_name()) {
1422 Some(_) => {
1423 is_contain = true;
1424 TreeNodeRecursion::Stop
1425 }
1426 None => TreeNodeRecursion::Continue,
1427 }
1428 } else {
1429 TreeNodeRecursion::Continue
1430 })
1431 })
1432 .unwrap();
1433 is_contain
1434}
1435
1436#[cfg(test)]
1437mod tests {
1438 use std::any::Any;
1439 use std::cmp::Ordering;
1440 use std::fmt::{Debug, Formatter};
1441
1442 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1443 use async_trait::async_trait;
1444
1445 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1446 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1447 use datafusion_expr::logical_plan::table_scan;
1448 use datafusion_expr::{
1449 ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1450 ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1451 UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1452 in_subquery, lit,
1453 };
1454
1455 use crate::OptimizerContext;
1456 use crate::assert_optimized_plan_eq_snapshot;
1457 use crate::optimizer::Optimizer;
1458 use crate::simplify_expressions::SimplifyExpressions;
1459 use crate::test::udfs::leaf_udf_expr;
1460 use crate::test::*;
1461 use datafusion_expr::test::function_stub::sum;
1462 use insta::assert_snapshot;
1463
1464 use super::*;
1465
1466 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1467
1468 macro_rules! assert_optimized_plan_equal {
1469 (
1470 $plan:expr,
1471 @ $expected:literal $(,)?
1472 ) => {{
1473 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1474 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1475 assert_optimized_plan_eq_snapshot!(
1476 optimizer_ctx,
1477 rules,
1478 $plan,
1479 @ $expected,
1480 )
1481 }};
1482 }
1483
1484 macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1485 (
1486 $plan:expr,
1487 @ $expected:literal $(,)?
1488 ) => {{
1489 let optimizer = Optimizer::with_rules(vec![
1490 Arc::new(SimplifyExpressions::new()),
1491 Arc::new(PushDownFilter::new()),
1492 ]);
1493 let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1494 assert_snapshot!(optimized_plan, @ $expected);
1495 Ok::<(), DataFusionError>(())
1496 }};
1497 }
1498
1499 #[test]
1500 fn filter_before_projection() -> Result<()> {
1501 let table_scan = test_table_scan()?;
1502 let plan = LogicalPlanBuilder::from(table_scan)
1503 .project(vec![col("a"), col("b")])?
1504 .filter(col("a").eq(lit(1i64)))?
1505 .build()?;
1506 assert_optimized_plan_equal!(
1508 plan,
1509 @r"
1510 Projection: test.a, test.b
1511 TableScan: test, full_filters=[test.a = Int64(1)]
1512 "
1513 )
1514 }
1515
1516 #[test]
1517 fn filter_after_limit() -> Result<()> {
1518 let table_scan = test_table_scan()?;
1519 let plan = LogicalPlanBuilder::from(table_scan)
1520 .project(vec![col("a"), col("b")])?
1521 .limit(0, Some(10))?
1522 .filter(col("a").eq(lit(1i64)))?
1523 .build()?;
1524 assert_optimized_plan_equal!(
1526 plan,
1527 @r"
1528 Filter: test.a = Int64(1)
1529 Limit: skip=0, fetch=10
1530 Projection: test.a, test.b
1531 TableScan: test
1532 "
1533 )
1534 }
1535
1536 #[test]
1537 fn filter_no_columns() -> Result<()> {
1538 let table_scan = test_table_scan()?;
1539 let plan = LogicalPlanBuilder::from(table_scan)
1540 .filter(lit(0i64).eq(lit(1i64)))?
1541 .build()?;
1542 assert_optimized_plan_equal!(
1543 plan,
1544 @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1545 )
1546 }
1547
1548 #[test]
1549 fn filter_jump_2_plans() -> Result<()> {
1550 let table_scan = test_table_scan()?;
1551 let plan = LogicalPlanBuilder::from(table_scan)
1552 .project(vec![col("a"), col("b"), col("c")])?
1553 .project(vec![col("c"), col("b")])?
1554 .filter(col("a").eq(lit(1i64)))?
1555 .build()?;
1556 assert_optimized_plan_equal!(
1558 plan,
1559 @r"
1560 Projection: test.c, test.b
1561 Projection: test.a, test.b, test.c
1562 TableScan: test, full_filters=[test.a = Int64(1)]
1563 "
1564 )
1565 }
1566
1567 #[test]
1568 fn filter_move_agg() -> Result<()> {
1569 let table_scan = test_table_scan()?;
1570 let plan = LogicalPlanBuilder::from(table_scan)
1571 .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1572 .filter(col("a").gt(lit(10i64)))?
1573 .build()?;
1574 assert_optimized_plan_equal!(
1576 plan,
1577 @r"
1578 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1579 TableScan: test, full_filters=[test.a > Int64(10)]
1580 "
1581 )
1582 }
1583
1584 #[test]
1586 fn filter_move_agg_special() -> Result<()> {
1587 let schema = Schema::new(vec![
1588 Field::new("$a", DataType::UInt32, false),
1589 Field::new("$b", DataType::UInt32, false),
1590 Field::new("$c", DataType::UInt32, false),
1591 ]);
1592 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1593
1594 let plan = LogicalPlanBuilder::from(table_scan)
1595 .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1596 .filter(col("$a").gt(lit(10i64)))?
1597 .build()?;
1598 assert_optimized_plan_equal!(
1600 plan,
1601 @r"
1602 Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1603 TableScan: test, full_filters=[test.$a > Int64(10)]
1604 "
1605 )
1606 }
1607
1608 #[test]
1609 fn filter_complex_group_by() -> Result<()> {
1610 let table_scan = test_table_scan()?;
1611 let plan = LogicalPlanBuilder::from(table_scan)
1612 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1613 .filter(col("b").gt(lit(10i64)))?
1614 .build()?;
1615 assert_optimized_plan_equal!(
1616 plan,
1617 @r"
1618 Filter: test.b > Int64(10)
1619 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1620 TableScan: test
1621 "
1622 )
1623 }
1624
1625 #[test]
1626 fn push_agg_need_replace_expr() -> Result<()> {
1627 let plan = LogicalPlanBuilder::from(test_table_scan()?)
1628 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1629 .filter(col("test.b + test.a").gt(lit(10i64)))?
1630 .build()?;
1631 assert_optimized_plan_equal!(
1632 plan,
1633 @r"
1634 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1635 TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1636 "
1637 )
1638 }
1639
1640 #[test]
1641 fn filter_keep_agg() -> Result<()> {
1642 let table_scan = test_table_scan()?;
1643 let plan = LogicalPlanBuilder::from(table_scan)
1644 .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1645 .filter(col("b").gt(lit(10i64)))?
1646 .build()?;
1647 assert_optimized_plan_equal!(
1649 plan,
1650 @r"
1651 Filter: b > Int64(10)
1652 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1653 TableScan: test
1654 "
1655 )
1656 }
1657
1658 #[test]
1660 fn filter_move_window() -> Result<()> {
1661 let table_scan = test_table_scan()?;
1662
1663 let window = Expr::from(WindowFunction::new(
1664 WindowFunctionDefinition::WindowUDF(
1665 datafusion_functions_window::rank::rank_udwf(),
1666 ),
1667 vec![],
1668 ))
1669 .partition_by(vec![col("a"), col("b")])
1670 .order_by(vec![col("c").sort(true, true)])
1671 .build()
1672 .unwrap();
1673
1674 let plan = LogicalPlanBuilder::from(table_scan)
1675 .window(vec![window])?
1676 .filter(col("b").gt(lit(10i64)))?
1677 .build()?;
1678
1679 assert_optimized_plan_equal!(
1680 plan,
1681 @r"
1682 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1683 TableScan: test, full_filters=[test.b > Int64(10)]
1684 "
1685 )
1686 }
1687
1688 #[test]
1690 fn filter_window_special_identifier() -> Result<()> {
1691 let schema = Schema::new(vec![
1692 Field::new("$a", DataType::UInt32, false),
1693 Field::new("$b", DataType::UInt32, false),
1694 Field::new("$c", DataType::UInt32, false),
1695 ]);
1696 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1697
1698 let window = Expr::from(WindowFunction::new(
1699 WindowFunctionDefinition::WindowUDF(
1700 datafusion_functions_window::rank::rank_udwf(),
1701 ),
1702 vec![],
1703 ))
1704 .partition_by(vec![col("$a"), col("$b")])
1705 .order_by(vec![col("$c").sort(true, true)])
1706 .build()
1707 .unwrap();
1708
1709 let plan = LogicalPlanBuilder::from(table_scan)
1710 .window(vec![window])?
1711 .filter(col("$b").gt(lit(10i64)))?
1712 .build()?;
1713
1714 assert_optimized_plan_equal!(
1715 plan,
1716 @r"
1717 WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1718 TableScan: test, full_filters=[test.$b > Int64(10)]
1719 "
1720 )
1721 }
1722
1723 #[test]
1726 fn filter_move_complex_window() -> Result<()> {
1727 let table_scan = test_table_scan()?;
1728
1729 let window = Expr::from(WindowFunction::new(
1730 WindowFunctionDefinition::WindowUDF(
1731 datafusion_functions_window::rank::rank_udwf(),
1732 ),
1733 vec![],
1734 ))
1735 .partition_by(vec![col("a"), col("b")])
1736 .order_by(vec![col("c").sort(true, true)])
1737 .build()
1738 .unwrap();
1739
1740 let plan = LogicalPlanBuilder::from(table_scan)
1741 .window(vec![window])?
1742 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1743 .build()?;
1744
1745 assert_optimized_plan_equal!(
1746 plan,
1747 @r"
1748 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1749 TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1750 "
1751 )
1752 }
1753
1754 #[test]
1756 fn filter_move_partial_window() -> Result<()> {
1757 let table_scan = test_table_scan()?;
1758
1759 let window = Expr::from(WindowFunction::new(
1760 WindowFunctionDefinition::WindowUDF(
1761 datafusion_functions_window::rank::rank_udwf(),
1762 ),
1763 vec![],
1764 ))
1765 .partition_by(vec![col("a")])
1766 .order_by(vec![col("c").sort(true, true)])
1767 .build()
1768 .unwrap();
1769
1770 let plan = LogicalPlanBuilder::from(table_scan)
1771 .window(vec![window])?
1772 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1773 .build()?;
1774
1775 assert_optimized_plan_equal!(
1776 plan,
1777 @r"
1778 Filter: test.b = Int64(1)
1779 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1780 TableScan: test, full_filters=[test.a > Int64(10)]
1781 "
1782 )
1783 }
1784
1785 #[test]
1788 fn filter_expression_keep_window() -> Result<()> {
1789 let table_scan = test_table_scan()?;
1790
1791 let window = Expr::from(WindowFunction::new(
1792 WindowFunctionDefinition::WindowUDF(
1793 datafusion_functions_window::rank::rank_udwf(),
1794 ),
1795 vec![],
1796 ))
1797 .partition_by(vec![add(col("a"), col("b"))]) .order_by(vec![col("c").sort(true, true)])
1799 .build()
1800 .unwrap();
1801
1802 let plan = LogicalPlanBuilder::from(table_scan)
1803 .window(vec![window])?
1804 .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1807 .build()?;
1808
1809 assert_optimized_plan_equal!(
1810 plan,
1811 @r"
1812 Filter: test.a + test.b > Int64(10)
1813 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1814 TableScan: test
1815 "
1816 )
1817 }
1818
1819 #[test]
1821 fn filter_order_keep_window() -> Result<()> {
1822 let table_scan = test_table_scan()?;
1823
1824 let window = Expr::from(WindowFunction::new(
1825 WindowFunctionDefinition::WindowUDF(
1826 datafusion_functions_window::rank::rank_udwf(),
1827 ),
1828 vec![],
1829 ))
1830 .partition_by(vec![col("a")])
1831 .order_by(vec![col("c").sort(true, true)])
1832 .build()
1833 .unwrap();
1834
1835 let plan = LogicalPlanBuilder::from(table_scan)
1836 .window(vec![window])?
1837 .filter(col("c").gt(lit(10i64)))?
1838 .build()?;
1839
1840 assert_optimized_plan_equal!(
1841 plan,
1842 @r"
1843 Filter: test.c > Int64(10)
1844 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1845 TableScan: test
1846 "
1847 )
1848 }
1849
1850 #[test]
1853 fn filter_multiple_windows_common_partitions() -> Result<()> {
1854 let table_scan = test_table_scan()?;
1855
1856 let window1 = Expr::from(WindowFunction::new(
1857 WindowFunctionDefinition::WindowUDF(
1858 datafusion_functions_window::rank::rank_udwf(),
1859 ),
1860 vec![],
1861 ))
1862 .partition_by(vec![col("a")])
1863 .order_by(vec![col("c").sort(true, true)])
1864 .build()
1865 .unwrap();
1866
1867 let window2 = Expr::from(WindowFunction::new(
1868 WindowFunctionDefinition::WindowUDF(
1869 datafusion_functions_window::rank::rank_udwf(),
1870 ),
1871 vec![],
1872 ))
1873 .partition_by(vec![col("b"), col("a")])
1874 .order_by(vec![col("c").sort(true, true)])
1875 .build()
1876 .unwrap();
1877
1878 let plan = LogicalPlanBuilder::from(table_scan)
1879 .window(vec![window1, window2])?
1880 .filter(col("a").gt(lit(10i64)))? .build()?;
1882
1883 assert_optimized_plan_equal!(
1884 plan,
1885 @r"
1886 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1887 TableScan: test, full_filters=[test.a > Int64(10)]
1888 "
1889 )
1890 }
1891
1892 #[test]
1895 fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1896 let table_scan = test_table_scan()?;
1897
1898 let window1 = Expr::from(WindowFunction::new(
1899 WindowFunctionDefinition::WindowUDF(
1900 datafusion_functions_window::rank::rank_udwf(),
1901 ),
1902 vec![],
1903 ))
1904 .partition_by(vec![col("a")])
1905 .order_by(vec![col("c").sort(true, true)])
1906 .build()
1907 .unwrap();
1908
1909 let window2 = Expr::from(WindowFunction::new(
1910 WindowFunctionDefinition::WindowUDF(
1911 datafusion_functions_window::rank::rank_udwf(),
1912 ),
1913 vec![],
1914 ))
1915 .partition_by(vec![col("b"), col("a")])
1916 .order_by(vec![col("c").sort(true, true)])
1917 .build()
1918 .unwrap();
1919
1920 let plan = LogicalPlanBuilder::from(table_scan)
1921 .window(vec![window1, window2])?
1922 .filter(col("b").gt(lit(10i64)))? .build()?;
1924
1925 assert_optimized_plan_equal!(
1926 plan,
1927 @r"
1928 Filter: test.b > Int64(10)
1929 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1930 TableScan: test
1931 "
1932 )
1933 }
1934
1935 #[test]
1937 fn alias() -> Result<()> {
1938 let table_scan = test_table_scan()?;
1939 let plan = LogicalPlanBuilder::from(table_scan)
1940 .project(vec![col("a").alias("b"), col("c")])?
1941 .filter(col("b").eq(lit(1i64)))?
1942 .build()?;
1943 assert_optimized_plan_equal!(
1945 plan,
1946 @r"
1947 Projection: test.a AS b, test.c
1948 TableScan: test, full_filters=[test.a = Int64(1)]
1949 "
1950 )
1951 }
1952
1953 fn add(left: Expr, right: Expr) -> Expr {
1954 Expr::BinaryExpr(BinaryExpr::new(
1955 Box::new(left),
1956 Operator::Plus,
1957 Box::new(right),
1958 ))
1959 }
1960
1961 fn multiply(left: Expr, right: Expr) -> Expr {
1962 Expr::BinaryExpr(BinaryExpr::new(
1963 Box::new(left),
1964 Operator::Multiply,
1965 Box::new(right),
1966 ))
1967 }
1968
1969 #[test]
1971 fn complex_expression() -> Result<()> {
1972 let table_scan = test_table_scan()?;
1973 let plan = LogicalPlanBuilder::from(table_scan)
1974 .project(vec![
1975 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1976 col("c"),
1977 ])?
1978 .filter(col("b").eq(lit(1i64)))?
1979 .build()?;
1980
1981 assert_snapshot!(plan,
1983 @r"
1984 Filter: b = Int64(1)
1985 Projection: test.a * Int32(2) + test.c AS b, test.c
1986 TableScan: test
1987 ",
1988 );
1989 assert_optimized_plan_equal!(
1991 plan,
1992 @r"
1993 Projection: test.a * Int32(2) + test.c AS b, test.c
1994 TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1995 "
1996 )
1997 }
1998
1999 #[test]
2001 fn complex_plan() -> Result<()> {
2002 let table_scan = test_table_scan()?;
2003 let plan = LogicalPlanBuilder::from(table_scan)
2004 .project(vec![
2005 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2006 col("c"),
2007 ])?
2008 .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2010 .filter(col("a").eq(lit(1i64)))?
2011 .build()?;
2012
2013 assert_snapshot!(plan,
2015 @r"
2016 Filter: a = Int64(1)
2017 Projection: b * Int32(3) AS a, test.c
2018 Projection: test.a * Int32(2) + test.c AS b, test.c
2019 TableScan: test
2020 ",
2021 );
2022 assert_optimized_plan_equal!(
2024 plan,
2025 @r"
2026 Projection: b * Int32(3) AS a, test.c
2027 Projection: test.a * Int32(2) + test.c AS b, test.c
2028 TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2029 "
2030 )
2031 }
2032
2033 #[derive(Debug, PartialEq, Eq, Hash)]
2034 struct NoopPlan {
2035 input: Vec<LogicalPlan>,
2036 schema: DFSchemaRef,
2037 }
2038
2039 impl PartialOrd for NoopPlan {
2041 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2042 self.input
2043 .partial_cmp(&other.input)
2044 .filter(|cmp| *cmp != Ordering::Equal || self == other)
2046 }
2047 }
2048
2049 impl UserDefinedLogicalNodeCore for NoopPlan {
2050 fn name(&self) -> &str {
2051 "NoopPlan"
2052 }
2053
2054 fn inputs(&self) -> Vec<&LogicalPlan> {
2055 self.input.iter().collect()
2056 }
2057
2058 fn schema(&self) -> &DFSchemaRef {
2059 &self.schema
2060 }
2061
2062 fn expressions(&self) -> Vec<Expr> {
2063 self.input
2064 .iter()
2065 .flat_map(|child| child.expressions())
2066 .collect()
2067 }
2068
2069 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2070 HashSet::from_iter(vec!["c".to_string()])
2071 }
2072
2073 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2074 write!(f, "NoopPlan")
2075 }
2076
2077 fn with_exprs_and_inputs(
2078 &self,
2079 _exprs: Vec<Expr>,
2080 inputs: Vec<LogicalPlan>,
2081 ) -> Result<Self> {
2082 Ok(Self {
2083 input: inputs,
2084 schema: Arc::clone(&self.schema),
2085 })
2086 }
2087
2088 fn supports_limit_pushdown(&self) -> bool {
2089 false }
2091 }
2092
2093 #[test]
2094 fn user_defined_plan() -> Result<()> {
2095 let table_scan = test_table_scan()?;
2096
2097 let custom_plan = LogicalPlan::Extension(Extension {
2098 node: Arc::new(NoopPlan {
2099 input: vec![table_scan.clone()],
2100 schema: Arc::clone(table_scan.schema()),
2101 }),
2102 });
2103 let plan = LogicalPlanBuilder::from(custom_plan)
2104 .filter(col("a").eq(lit(1i64)))?
2105 .build()?;
2106
2107 assert_optimized_plan_equal!(
2109 plan,
2110 @r"
2111 NoopPlan
2112 TableScan: test, full_filters=[test.a = Int64(1)]
2113 "
2114 )?;
2115
2116 let custom_plan = LogicalPlan::Extension(Extension {
2117 node: Arc::new(NoopPlan {
2118 input: vec![table_scan.clone()],
2119 schema: Arc::clone(table_scan.schema()),
2120 }),
2121 });
2122 let plan = LogicalPlanBuilder::from(custom_plan)
2123 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2124 .build()?;
2125
2126 assert_optimized_plan_equal!(
2128 plan,
2129 @r"
2130 Filter: test.c = Int64(2)
2131 NoopPlan
2132 TableScan: test, full_filters=[test.a = Int64(1)]
2133 "
2134 )?;
2135
2136 let custom_plan = LogicalPlan::Extension(Extension {
2137 node: Arc::new(NoopPlan {
2138 input: vec![table_scan.clone(), table_scan.clone()],
2139 schema: Arc::clone(table_scan.schema()),
2140 }),
2141 });
2142 let plan = LogicalPlanBuilder::from(custom_plan)
2143 .filter(col("a").eq(lit(1i64)))?
2144 .build()?;
2145
2146 assert_optimized_plan_equal!(
2148 plan,
2149 @r"
2150 NoopPlan
2151 TableScan: test, full_filters=[test.a = Int64(1)]
2152 TableScan: test, full_filters=[test.a = Int64(1)]
2153 "
2154 )?;
2155
2156 let custom_plan = LogicalPlan::Extension(Extension {
2157 node: Arc::new(NoopPlan {
2158 input: vec![table_scan.clone(), table_scan.clone()],
2159 schema: Arc::clone(table_scan.schema()),
2160 }),
2161 });
2162 let plan = LogicalPlanBuilder::from(custom_plan)
2163 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2164 .build()?;
2165
2166 assert_optimized_plan_equal!(
2168 plan,
2169 @r"
2170 Filter: test.c = Int64(2)
2171 NoopPlan
2172 TableScan: test, full_filters=[test.a = Int64(1)]
2173 TableScan: test, full_filters=[test.a = Int64(1)]
2174 "
2175 )
2176 }
2177
2178 #[test]
2181 fn multi_filter() -> Result<()> {
2182 let table_scan = test_table_scan()?;
2184 let plan = LogicalPlanBuilder::from(table_scan)
2185 .project(vec![col("a").alias("b"), col("c")])?
2186 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2187 .filter(col("b").gt(lit(10i64)))?
2188 .filter(col("sum(test.c)").gt(lit(10i64)))?
2189 .build()?;
2190
2191 assert_snapshot!(plan,
2193 @r"
2194 Filter: sum(test.c) > Int64(10)
2195 Filter: b > Int64(10)
2196 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2197 Projection: test.a AS b, test.c
2198 TableScan: test
2199 ",
2200 );
2201 assert_optimized_plan_equal!(
2203 plan,
2204 @r"
2205 Filter: sum(test.c) > Int64(10)
2206 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2207 Projection: test.a AS b, test.c
2208 TableScan: test, full_filters=[test.a > Int64(10)]
2209 "
2210 )
2211 }
2212
2213 #[test]
2216 fn split_filter() -> Result<()> {
2217 let table_scan = test_table_scan()?;
2219 let plan = LogicalPlanBuilder::from(table_scan)
2220 .project(vec![col("a").alias("b"), col("c")])?
2221 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2222 .filter(and(
2223 col("sum(test.c)").gt(lit(10i64)),
2224 and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2225 ))?
2226 .build()?;
2227
2228 assert_snapshot!(plan,
2230 @r"
2231 Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2232 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2233 Projection: test.a AS b, test.c
2234 TableScan: test
2235 ",
2236 );
2237 assert_optimized_plan_equal!(
2239 plan,
2240 @r"
2241 Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2242 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2243 Projection: test.a AS b, test.c
2244 TableScan: test, full_filters=[test.a > Int64(10)]
2245 "
2246 )
2247 }
2248
2249 #[test]
2251 fn double_limit() -> Result<()> {
2252 let table_scan = test_table_scan()?;
2253 let plan = LogicalPlanBuilder::from(table_scan)
2254 .project(vec![col("a"), col("b")])?
2255 .limit(0, Some(20))?
2256 .limit(0, Some(10))?
2257 .project(vec![col("a"), col("b")])?
2258 .filter(col("a").eq(lit(1i64)))?
2259 .build()?;
2260 assert_optimized_plan_equal!(
2262 plan,
2263 @r"
2264 Projection: test.a, test.b
2265 Filter: test.a = Int64(1)
2266 Limit: skip=0, fetch=10
2267 Limit: skip=0, fetch=20
2268 Projection: test.a, test.b
2269 TableScan: test
2270 "
2271 )
2272 }
2273
2274 #[test]
2275 fn union_all() -> Result<()> {
2276 let table_scan = test_table_scan()?;
2277 let table_scan2 = test_table_scan_with_name("test2")?;
2278 let plan = LogicalPlanBuilder::from(table_scan)
2279 .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2280 .filter(col("a").eq(lit(1i64)))?
2281 .build()?;
2282 assert_optimized_plan_equal!(
2284 plan,
2285 @r"
2286 Union
2287 TableScan: test, full_filters=[test.a = Int64(1)]
2288 TableScan: test2, full_filters=[test2.a = Int64(1)]
2289 "
2290 )
2291 }
2292
2293 #[test]
2294 fn union_all_on_projection() -> Result<()> {
2295 let table_scan = test_table_scan()?;
2296 let table = LogicalPlanBuilder::from(table_scan)
2297 .project(vec![col("a").alias("b")])?
2298 .alias("test2")?;
2299
2300 let plan = table
2301 .clone()
2302 .union(table.build()?)?
2303 .filter(col("b").eq(lit(1i64)))?
2304 .build()?;
2305
2306 assert_optimized_plan_equal!(
2308 plan,
2309 @r"
2310 Union
2311 SubqueryAlias: test2
2312 Projection: test.a AS b
2313 TableScan: test, full_filters=[test.a = Int64(1)]
2314 SubqueryAlias: test2
2315 Projection: test.a AS b
2316 TableScan: test, full_filters=[test.a = Int64(1)]
2317 "
2318 )
2319 }
2320
2321 #[test]
2322 fn test_union_different_schema() -> Result<()> {
2323 let left = LogicalPlanBuilder::from(test_table_scan()?)
2324 .project(vec![col("a"), col("b"), col("c")])?
2325 .build()?;
2326
2327 let schema = Schema::new(vec![
2328 Field::new("d", DataType::UInt32, false),
2329 Field::new("e", DataType::UInt32, false),
2330 Field::new("f", DataType::UInt32, false),
2331 ]);
2332 let right = table_scan(Some("test1"), &schema, None)?
2333 .project(vec![col("d"), col("e"), col("f")])?
2334 .build()?;
2335 let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2336 let plan = LogicalPlanBuilder::from(left)
2337 .cross_join(right)?
2338 .project(vec![col("test.a"), col("test1.d")])?
2339 .filter(filter)?
2340 .build()?;
2341
2342 assert_optimized_plan_equal!(
2343 plan,
2344 @r"
2345 Projection: test.a, test1.d
2346 Cross Join:
2347 Projection: test.a, test.b, test.c
2348 TableScan: test, full_filters=[test.a = Int32(1)]
2349 Projection: test1.d, test1.e, test1.f
2350 TableScan: test1, full_filters=[test1.d > Int32(2)]
2351 "
2352 )
2353 }
2354
2355 #[test]
2356 fn test_project_same_name_different_qualifier() -> Result<()> {
2357 let table_scan = test_table_scan()?;
2358 let left = LogicalPlanBuilder::from(table_scan)
2359 .project(vec![col("a"), col("b"), col("c")])?
2360 .build()?;
2361 let right_table_scan = test_table_scan_with_name("test1")?;
2362 let right = LogicalPlanBuilder::from(right_table_scan)
2363 .project(vec![col("a"), col("b"), col("c")])?
2364 .build()?;
2365 let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2366 let plan = LogicalPlanBuilder::from(left)
2367 .cross_join(right)?
2368 .project(vec![col("test.a"), col("test1.a")])?
2369 .filter(filter)?
2370 .build()?;
2371
2372 assert_optimized_plan_equal!(
2373 plan,
2374 @r"
2375 Projection: test.a, test1.a
2376 Cross Join:
2377 Projection: test.a, test.b, test.c
2378 TableScan: test, full_filters=[test.a = Int32(1)]
2379 Projection: test1.a, test1.b, test1.c
2380 TableScan: test1, full_filters=[test1.a > Int32(2)]
2381 "
2382 )
2383 }
2384
2385 #[test]
2387 fn filter_2_breaks_limits() -> Result<()> {
2388 let table_scan = test_table_scan()?;
2389 let plan = LogicalPlanBuilder::from(table_scan)
2390 .project(vec![col("a")])?
2391 .filter(col("a").lt_eq(lit(1i64)))?
2392 .limit(0, Some(1))?
2393 .project(vec![col("a")])?
2394 .filter(col("a").gt_eq(lit(1i64)))?
2395 .build()?;
2396 assert_snapshot!(plan,
2400 @r"
2401 Filter: test.a >= Int64(1)
2402 Projection: test.a
2403 Limit: skip=0, fetch=1
2404 Filter: test.a <= Int64(1)
2405 Projection: test.a
2406 TableScan: test
2407 ",
2408 );
2409 assert_optimized_plan_equal!(
2410 plan,
2411 @r"
2412 Projection: test.a
2413 Filter: test.a >= Int64(1)
2414 Limit: skip=0, fetch=1
2415 Projection: test.a
2416 TableScan: test, full_filters=[test.a <= Int64(1)]
2417 "
2418 )
2419 }
2420
2421 #[test]
2423 fn two_filters_on_same_depth() -> Result<()> {
2424 let table_scan = test_table_scan()?;
2425 let plan = LogicalPlanBuilder::from(table_scan)
2426 .limit(0, Some(1))?
2427 .filter(col("a").lt_eq(lit(1i64)))?
2428 .filter(col("a").gt_eq(lit(1i64)))?
2429 .project(vec![col("a")])?
2430 .build()?;
2431
2432 assert_snapshot!(plan,
2434 @r"
2435 Projection: test.a
2436 Filter: test.a >= Int64(1)
2437 Filter: test.a <= Int64(1)
2438 Limit: skip=0, fetch=1
2439 TableScan: test
2440 ",
2441 );
2442 assert_optimized_plan_equal!(
2443 plan,
2444 @r"
2445 Projection: test.a
2446 Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2447 Limit: skip=0, fetch=1
2448 TableScan: test
2449 "
2450 )
2451 }
2452
2453 #[test]
2456 fn filters_user_defined_node() -> Result<()> {
2457 let table_scan = test_table_scan()?;
2458 let plan = LogicalPlanBuilder::from(table_scan)
2459 .filter(col("a").lt_eq(lit(1i64)))?
2460 .build()?;
2461
2462 let plan = user_defined::new(plan);
2463
2464 assert_snapshot!(plan,
2466 @r"
2467 TestUserDefined
2468 Filter: test.a <= Int64(1)
2469 TableScan: test
2470 ",
2471 );
2472 assert_optimized_plan_equal!(
2473 plan,
2474 @r"
2475 TestUserDefined
2476 TableScan: test, full_filters=[test.a <= Int64(1)]
2477 "
2478 )
2479 }
2480
2481 #[test]
2483 fn filter_on_join_on_common_independent() -> Result<()> {
2484 let table_scan = test_table_scan()?;
2485 let left = LogicalPlanBuilder::from(table_scan).build()?;
2486 let right_table_scan = test_table_scan_with_name("test2")?;
2487 let right = LogicalPlanBuilder::from(right_table_scan)
2488 .project(vec![col("a")])?
2489 .build()?;
2490 let plan = LogicalPlanBuilder::from(left)
2491 .join(
2492 right,
2493 JoinType::Inner,
2494 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2495 None,
2496 )?
2497 .filter(col("test.a").lt_eq(lit(1i64)))?
2498 .build()?;
2499
2500 assert_snapshot!(plan,
2502 @r"
2503 Filter: test.a <= Int64(1)
2504 Inner Join: test.a = test2.a
2505 TableScan: test
2506 Projection: test2.a
2507 TableScan: test2
2508 ",
2509 );
2510 assert_optimized_plan_equal!(
2512 plan,
2513 @r"
2514 Inner Join: test.a = test2.a
2515 TableScan: test, full_filters=[test.a <= Int64(1)]
2516 Projection: test2.a
2517 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2518 "
2519 )
2520 }
2521
2522 #[test]
2524 fn filter_using_join_on_common_independent() -> Result<()> {
2525 let table_scan = test_table_scan()?;
2526 let left = LogicalPlanBuilder::from(table_scan).build()?;
2527 let right_table_scan = test_table_scan_with_name("test2")?;
2528 let right = LogicalPlanBuilder::from(right_table_scan)
2529 .project(vec![col("a")])?
2530 .build()?;
2531 let plan = LogicalPlanBuilder::from(left)
2532 .join_using(
2533 right,
2534 JoinType::Inner,
2535 vec![Column::from_name("a".to_string())],
2536 )?
2537 .filter(col("a").lt_eq(lit(1i64)))?
2538 .build()?;
2539
2540 assert_snapshot!(plan,
2542 @r"
2543 Filter: test.a <= Int64(1)
2544 Inner Join: Using test.a = test2.a
2545 TableScan: test
2546 Projection: test2.a
2547 TableScan: test2
2548 ",
2549 );
2550 assert_optimized_plan_equal!(
2552 plan,
2553 @r"
2554 Inner Join: Using test.a = test2.a
2555 TableScan: test, full_filters=[test.a <= Int64(1)]
2556 Projection: test2.a
2557 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2558 "
2559 )
2560 }
2561
2562 #[test]
2564 fn filter_join_on_common_dependent() -> Result<()> {
2565 let table_scan = test_table_scan()?;
2566 let left = LogicalPlanBuilder::from(table_scan)
2567 .project(vec![col("a"), col("c")])?
2568 .build()?;
2569 let right_table_scan = test_table_scan_with_name("test2")?;
2570 let right = LogicalPlanBuilder::from(right_table_scan)
2571 .project(vec![col("a"), col("b")])?
2572 .build()?;
2573 let plan = LogicalPlanBuilder::from(left)
2574 .join(
2575 right,
2576 JoinType::Inner,
2577 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2578 None,
2579 )?
2580 .filter(col("c").lt_eq(col("b")))?
2581 .build()?;
2582
2583 assert_snapshot!(plan,
2585 @r"
2586 Filter: test.c <= test2.b
2587 Inner Join: test.a = test2.a
2588 Projection: test.a, test.c
2589 TableScan: test
2590 Projection: test2.a, test2.b
2591 TableScan: test2
2592 ",
2593 );
2594 assert_optimized_plan_equal!(
2596 plan,
2597 @r"
2598 Inner Join: test.a = test2.a Filter: test.c <= test2.b
2599 Projection: test.a, test.c
2600 TableScan: test
2601 Projection: test2.a, test2.b
2602 TableScan: test2
2603 "
2604 )
2605 }
2606
2607 #[test]
2609 fn filter_join_on_one_side() -> Result<()> {
2610 let table_scan = test_table_scan()?;
2611 let left = LogicalPlanBuilder::from(table_scan)
2612 .project(vec![col("a"), col("b")])?
2613 .build()?;
2614 let table_scan_right = test_table_scan_with_name("test2")?;
2615 let right = LogicalPlanBuilder::from(table_scan_right)
2616 .project(vec![col("a"), col("c")])?
2617 .build()?;
2618
2619 let plan = LogicalPlanBuilder::from(left)
2620 .join(
2621 right,
2622 JoinType::Inner,
2623 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2624 None,
2625 )?
2626 .filter(col("b").lt_eq(lit(1i64)))?
2627 .build()?;
2628
2629 assert_snapshot!(plan,
2631 @r"
2632 Filter: test.b <= Int64(1)
2633 Inner Join: test.a = test2.a
2634 Projection: test.a, test.b
2635 TableScan: test
2636 Projection: test2.a, test2.c
2637 TableScan: test2
2638 ",
2639 );
2640 assert_optimized_plan_equal!(
2641 plan,
2642 @r"
2643 Inner Join: test.a = test2.a
2644 Projection: test.a, test.b
2645 TableScan: test, full_filters=[test.b <= Int64(1)]
2646 Projection: test2.a, test2.c
2647 TableScan: test2
2648 "
2649 )
2650 }
2651
2652 #[test]
2655 fn filter_using_left_join() -> Result<()> {
2656 let table_scan = test_table_scan()?;
2657 let left = LogicalPlanBuilder::from(table_scan).build()?;
2658 let right_table_scan = test_table_scan_with_name("test2")?;
2659 let right = LogicalPlanBuilder::from(right_table_scan)
2660 .project(vec![col("a")])?
2661 .build()?;
2662 let plan = LogicalPlanBuilder::from(left)
2663 .join_using(
2664 right,
2665 JoinType::Left,
2666 vec![Column::from_name("a".to_string())],
2667 )?
2668 .filter(col("test2.a").lt_eq(lit(1i64)))?
2669 .build()?;
2670
2671 assert_snapshot!(plan,
2673 @r"
2674 Filter: test2.a <= Int64(1)
2675 Left Join: Using test.a = test2.a
2676 TableScan: test
2677 Projection: test2.a
2678 TableScan: test2
2679 ",
2680 );
2681 assert_optimized_plan_equal!(
2683 plan,
2684 @r"
2685 Filter: test2.a <= Int64(1)
2686 Left Join: Using test.a = test2.a
2687 TableScan: test, full_filters=[test.a <= Int64(1)]
2688 Projection: test2.a
2689 TableScan: test2
2690 "
2691 )
2692 }
2693
/// `RIGHT JOIN ... USING (a)` followed by a filter on the left key `test.a`.
///
/// The expected optimized plan keeps the `Filter` above the join and only adds
/// the equivalent predicate on `test2.a` to the right scan.
#[test]
fn filter_using_right_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_keys = vec![Column::from_name("a")];
    let post_join_predicate = col("test.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Right, using_keys)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ");
    // The filter stays in place; the right scan gains the matching predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2734
/// `LEFT JOIN ... USING (a)` followed by a filter on the unqualified column `a`
/// (which the built plan shows as `test.a`).
///
/// The expected optimized plan has no `Filter` node left: the predicate shows up
/// on both scans, rewritten for each side's key column.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_keys = vec![Column::from_name("a")];
    let post_join_predicate = col("a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Left, using_keys)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ");
    // The filter is gone and both scans carry the predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2774
/// `RIGHT JOIN ... USING (a)` followed by a filter on the right key `test2.a`.
///
/// The expected optimized plan has no `Filter` node left: the predicate shows up
/// on both scans, rewritten for each side's key column.
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_keys = vec![Column::from_name("a")];
    let post_join_predicate = col("test2.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Right, using_keys)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test2.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ");
    // The filter is gone and both scans carry the predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2814
/// Inner join whose `ON` filter mixes a left-only, a two-sided and a right-only
/// conjunct.
///
/// Expected result: the single-sided conjuncts move to their scans and only the
/// two-sided comparison `test.b < test2.b` remains as the join filter.
#[test]
fn join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.c").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ");
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.c > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2860
/// Inner join whose `ON` filter consists solely of single-sided conjuncts.
///
/// Expected result: each conjunct moves to its own scan and the join is left
/// without any filter.
#[test]
fn join_filter_removed() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.b").gt(lit(1u32));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ");
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2905
/// Inner join on `test.a = test2.b` with an `ON` filter that references only the
/// left join key.
///
/// Expected result: the predicate lands on the left scan and an equivalent one,
/// rewritten to `test2.b`, lands on the right scan.
#[test]
fn join_filter_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;

    let on_filter = col("test.a").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
      Projection: test.a
        TableScan: test
      Projection: test2.b
        TableScan: test2
    ");
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.b
      Projection: test.a
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
2948
/// Left join whose `ON` filter has a left-key conjunct, a two-sided conjunct and
/// a right-only conjunct.
///
/// Expected result: nothing is pushed to the left scan; the right scan receives
/// `test2.c > 4` plus `test2.a > 1` (the left-key conjunct rewritten to the right
/// key), while the left-key and two-sided conjuncts stay in the join filter.
#[test]
fn left_join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_key = col("test.a").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_key.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Left, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ");
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)]
    "
    )
}
2994
/// Right join whose `ON` filter has a left-key conjunct, a two-sided conjunct and
/// a right-only conjunct.
///
/// Expected result: only the left-side conjunct `test.a > 1` is pushed (to the
/// left scan); the two-sided and right-only conjuncts stay in the join filter
/// and the right scan is untouched.
#[test]
fn right_join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_key = col("test.a").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_key.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Right, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ");
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3040
/// Full join whose `ON` filter has a left-key conjunct, a two-sided conjunct and
/// a right-only conjunct.
///
/// Expected result: the optimized plan is identical to the input plan, i.e.
/// nothing is pushed out of a full join's `ON` filter.
#[test]
fn full_join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_key = col("test.a").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_key.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Full, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ");
    // Same plan after optimization.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3086
/// Test-only table source that reports one fixed pushdown level,
/// `filter_support`, for every filter it is asked about.
struct PushDownProvider {
    /// Pushdown level returned for each filter by `supports_filters_pushdown`.
    pub filter_support: TableProviderFilterPushDown,
}
3090
3091 #[async_trait]
3092 impl TableSource for PushDownProvider {
3093 fn schema(&self) -> SchemaRef {
3094 Arc::new(Schema::new(vec![
3095 Field::new("a", DataType::Int32, true),
3096 Field::new("b", DataType::Int32, true),
3097 ]))
3098 }
3099
3100 fn table_type(&self) -> TableType {
3101 TableType::Base
3102 }
3103
3104 fn supports_filters_pushdown(
3105 &self,
3106 filters: &[&Expr],
3107 ) -> Result<Vec<TableProviderFilterPushDown>> {
3108 Ok((0..filters.len())
3109 .map(|_| self.filter_support.clone())
3110 .collect())
3111 }
3112
3113 fn as_any(&self) -> &dyn Any {
3114 self
3115 }
3116 }
3117
3118 fn table_scan_with_pushdown_provider_builder(
3119 filter_support: TableProviderFilterPushDown,
3120 filters: Vec<Expr>,
3121 projection: Option<Vec<usize>>,
3122 ) -> Result<LogicalPlanBuilder> {
3123 let test_provider = PushDownProvider { filter_support };
3124
3125 let table_scan = LogicalPlan::TableScan(TableScan {
3126 table_name: "test".into(),
3127 filters,
3128 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3129 projection,
3130 source: Arc::new(test_provider),
3131 fetch: None,
3132 });
3133
3134 Ok(LogicalPlanBuilder::from(table_scan))
3135 }
3136
3137 fn table_scan_with_pushdown_provider(
3138 filter_support: TableProviderFilterPushDown,
3139 ) -> Result<LogicalPlan> {
3140 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3141 .filter(col("a").eq(lit(1i64)))?
3142 .build()
3143 }
3144
/// With `Exact` support the expected plan is a bare scan: the predicate appears
/// in `full_filters` and no `Filter` node remains.
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
    let support = TableProviderFilterPushDown::Exact;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[a = Int64(1)]"
    )
}
3154
/// With `Inexact` support the predicate is expected both on the scan (as a
/// `partial_filters` entry) and in the `Filter` node that stays above it.
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3168
/// Rewrites an `Inexact` plan once by hand with `PushDownFilter` and then hands
/// the result to `assert_optimized_plan_equal!`.
///
/// The expected output equals the single-pass result of
/// `filter_with_table_provider_inexact`: one `Filter` and one `partial_filters`
/// entry, with no duplicates.
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    let config = OptimizerContext::new();
    let first_pass = PushDownFilter::new()
        .rewrite(plan, &config)
        .expect("failed to optimize plan");
    let optimized_plan = first_pass.data;

    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3189
/// With `Unsupported` the expected plan is unchanged: the `Filter` stays above a
/// scan that lists no filters.
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
    let support = TableProviderFilterPushDown::Unsupported;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test
    "
    )
}
3203
/// `Inexact` scan that already carries the two predicates the `Filter` above it
/// repeats.
///
/// Expected result: the scan still lists each predicate once in
/// `partial_filters` and the `Filter` node is kept.
#[test]
fn multi_combined_filter() -> Result<()> {
    let scan_filters = vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))];
    let predicate = and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64)));

    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Inexact,
        scan_filters,
        Some(vec![0]),
    )?;
    let plan = scan_builder
        .filter(predicate)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: a = Int64(10) AND b > Int64(11)
        TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3224
/// `Exact` scan with a projection and a two-part conjunctive `Filter` above it.
///
/// Expected result: both conjuncts end up in the scan's `full_filters` and the
/// `Filter` node disappears.
#[test]
fn multi_combined_filter_exact() -> Result<()> {
    let predicate = and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64)));

    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Exact,
        vec![],
        Some(vec![0]),
    )?;
    let plan = scan_builder
        .filter(predicate)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3244
/// Filter on an aliased projection column (`test.a AS b`) and on a plain column.
///
/// Expected result: the alias `b` is rewritten back to `test.a` and both
/// conjuncts end up on the scan.
#[test]
fn test_filter_with_alias() -> Result<()> {
    let projection = vec![col("a").alias("b"), col("c")];
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(projection)?
        .filter(predicate)?
        .build()?;

    // The filter initially sits above the aliasing projection.
    assert_snapshot!(plan, @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: test.a AS b, test.c
        TableScan: test
    ");
    // After the rule, the filter refers to `test.a` instead of `b`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3273
/// Same as `test_filter_with_alias`, but with an extra pass-through projection
/// between the aliasing projection and the filter.
///
/// Expected result: the predicate travels through both projections and reaches
/// the scan with `b` rewritten to `test.a`.
#[test]
fn test_filter_with_alias_2() -> Result<()> {
    let aliasing = vec![col("a").alias("b"), col("c")];
    let pass_through = vec![col("b"), col("c")];
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(aliasing)?
        .project(pass_through)?
        .filter(predicate)?
        .build()?;

    // The filter initially sits above both projections.
    assert_snapshot!(plan, @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ");
    // After the rule, the filter refers to `test.a` instead of `b`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3305
/// Filter over two aliased projection columns (`test.a AS b`, `test.c AS d`).
///
/// Expected result: both aliases are rewritten back to their source columns and
/// both conjuncts end up on the scan.
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
    let projection = vec![col("a").alias("b"), col("c").alias("d")];
    let predicate = and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(projection)?
        .filter(predicate)?
        .build()?;

    // The filter initially sits above the aliasing projection.
    assert_snapshot!(plan, @r"
    Filter: b > Int64(10) AND d > Int64(10)
      Projection: test.a AS b, test.c AS d
        TableScan: test
    ");
    // After the rule, the filter refers to `test.a` / `test.c`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c AS d
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3331
/// Inner join on aliased keys (`c = d`) with an `ON` filter on the left alias.
///
/// Expected result: the predicate reaches both scans, expressed through the
/// underlying columns `test.a` and `test2.b`, and the join keeps no filter.
#[test]
fn join_filter_with_alias() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;

    let on_filter = col("c").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Inner Join: c = d Filter: c > UInt32(1)
      Projection: test.a AS c
        TableScan: test
      Projection: test2.b AS d
        TableScan: test2
    ");
    // Both scans are filtered through the de-aliased columns.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: c = d
      Projection: test.a AS c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b AS d
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
3374
/// `IN (...)` list filter on an aliased projection column (`test.a AS b`).
///
/// Expected result: the `IN` predicate reaches the scan with `b` rewritten to
/// `test.a`.
#[test]
fn test_in_filter_with_alias() -> Result<()> {
    let candidates = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    let predicate = in_list(col("b"), candidates, false);
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // The filter initially sits above the aliasing projection.
    assert_snapshot!(plan, @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: test.a AS b, test.c
        TableScan: test
    ");
    // After the rule, the filter refers to `test.a` instead of `b`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3404
/// Same as `test_in_filter_with_alias`, but with an extra pass-through
/// projection between the aliasing projection and the filter.
///
/// Expected result: the `IN` predicate travels through both projections and
/// reaches the scan with `b` rewritten to `test.a`.
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
    let candidates = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    let predicate = in_list(col("b"), candidates, false);
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // The filter initially sits above both projections.
    assert_snapshot!(plan, @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ");
    // After the rule, the filter refers to `test.a` instead of `b`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3437
/// `IN (<subquery>)` filter on an aliased projection column (`test.a AS b`).
///
/// Expected result: the predicate, subquery included, reaches the scan with `b`
/// rewritten to `test.a`.
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
    let outer_scan = test_table_scan()?;
    let subquery_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .project(vec![col("c")])?
        .build()?;
    let subplan = Arc::new(subquery_plan);

    let plan = LogicalPlanBuilder::from(outer_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_subquery(col("b"), subplan))?
        .build()?;

    // The filter initially sits above the aliasing projection.
    assert_snapshot!(plan, @r"
    Filter: b IN (<subquery>)
      Subquery:
        Projection: sq.c
          TableScan: sq
      Projection: test.a AS b, test.c
        TableScan: test
    ");
    // After the rule, the filter refers to `test.a` instead of `b`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN (<subquery>)]
        Subquery:
          Projection: sq.c
            TableScan: sq
    "
    )
}
3477
/// Filter above two nested `SubqueryAlias`/`Projection` layers whose innermost
/// projection defines `a` as the literal `0`.
///
/// Expected result: the filter is carried through every layer, ends up directly
/// above the `EmptyRelation`, and is expressed on the literal
/// (`Int64(0) = Int64(1)`).
#[test]
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    // Innermost layer: a one-row relation projected to the constant column `a`.
    let inner = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?;
    // Second layer re-projects the column and re-applies the same alias.
    let outer = inner.project(vec![col("b.a")])?.alias("b")?;
    let plan = outer
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    assert_snapshot!(plan, @r"
    Projection: b.a
      Filter: b.a = Int64(1)
        SubqueryAlias: b
          Projection: b.a
            SubqueryAlias: b
              Projection: Int64(0) AS a
                EmptyRelation: rows=1
    ");
    // The filter has moved below every alias and projection.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b.a
      SubqueryAlias: b
        Projection: b.a
          SubqueryAlias: b
            Projection: Int64(0) AS a
              Filter: Int64(0) = Int64(1)
                EmptyRelation: rows=1
    "
    )
}
3516
/// Cross join followed by an `OR` of two conjunctions that each mix a two-sided
/// equality with a left-only comparison.
///
/// Expected result: the whole predicate becomes the filter of an inner join and
/// the left scan additionally receives `test.b > 1 OR test.c < 10`. The same
/// output is expected when the plan is rewritten once by hand before being
/// handed to `assert_optimized_plan_equal!`.
#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;

    let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
    let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
    let predicate = or(first_branch, second_branch);

    let plan = LogicalPlanBuilder::from(lhs)
        .cross_join(rhs)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
    Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    ")?;

    // Rewrite once by hand, then assert on the result of that rewrite.
    let config = OptimizerContext::new();
    let first_pass = PushDownFilter::new()
        .rewrite(plan, &config)
        .expect("failed to optimize plan");
    let optimized_plan = first_pass.data;
    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
    Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    "
    )
}
3562
/// Left semi join followed by a filter on the right key `test2.a`.
///
/// Expected result: the `Filter` stays above the join and the left scan gains
/// the equivalent predicate on `test1.a`.
#[test]
fn left_semi_join() -> Result<()> {
    let lhs = test_table_scan_with_name("test1")?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let post_join_predicate = col("test2.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::LeftSemi, join_keys, None)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ");
    // Only the left scan receives a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1, full_filters=[test1.a <= Int64(1)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3605
/// Left semi join whose `ON` filter has one left-only and one right-only
/// conjunct.
///
/// Expected result: each conjunct moves to its own scan and the join keeps no
/// filter.
#[test]
fn left_semi_join_with_filters() -> Result<()> {
    let lhs = test_table_scan_with_name("test1")?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::LeftSemi, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ");
    // Both sides receive their conjunct.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3649
/// Right semi join followed by a filter on the left key `test1.a`.
///
/// Expected result: the `Filter` stays above the join and the right scan gains
/// the equivalent predicate on `test2.a`.
#[test]
fn right_semi_join() -> Result<()> {
    let lhs = test_table_scan_with_name("test1")?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let post_join_predicate = col("test1.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::RightSemi, join_keys, None)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ");
    // Only the right scan receives a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
3692
/// Right semi join whose `ON` filter has one left-only and one right-only
/// conjunct.
///
/// Expected result: each conjunct moves to its own scan and the join keeps no
/// filter.
#[test]
fn right_semi_join_with_filters() -> Result<()> {
    let lhs = test_table_scan_with_name("test1")?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::RightSemi, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ");
    // Both sides receive their conjunct.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3736
/// Left anti join followed by a filter on the right key `test2.a`.
///
/// Expected result: the `Filter` stays above the join and the left scan gains
/// the equivalent predicate on `test1.a`.
#[test]
fn left_anti_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let post_join_predicate = col("test2.a").gt(lit(2u32));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::LeftAnti, join_keys, None)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ");
    // Only the left scan receives a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1, full_filters=[test1.a > UInt32(2)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3784
/// Left anti join whose `ON` filter has one left-only and one right-only
/// conjunct.
///
/// Expected result: only the right-only conjunct is pushed (to the right scan);
/// the left-only conjunct remains in the join filter.
#[test]
fn left_anti_join_with_filters() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::LeftAnti, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ");
    // The right scan is filtered; the left conjunct stays on the join.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3833
/// Right anti join followed by a filter on the left key `test1.a`.
///
/// Expected result: the `Filter` stays above the join and the right scan gains
/// the equivalent predicate on `test2.a`.
#[test]
fn right_anti_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let post_join_predicate = col("test1.a").gt(lit(2u32));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::RightAnti, join_keys, None)?
        .filter(post_join_predicate)?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ");
    // Only the right scan receives a predicate.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a > UInt32(2)]
    "
    )
}
3881
/// Right anti join whose `ON` filter has one left-only and one right-only
/// conjunct.
///
/// Expected result: only the left-only conjunct is pushed (to the left scan);
/// the right-only conjunct remains in the join filter.
#[test]
fn right_anti_join_with_filters() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::RightAnti, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule is applied.
    assert_snapshot!(plan, @r"
    RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ");
    // The left scan is filtered; the right conjunct stays on the join.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
3930
/// Minimal scalar UDF for tests; the caller supplies the `Signature` (and with
/// it the volatility) when constructing it.
#[derive(Debug, PartialEq, Eq, Hash)]
struct TestScalarUDF {
    /// Signature handed back by `ScalarUDFImpl::signature`.
    signature: Signature,
}
3935
3936 impl ScalarUDFImpl for TestScalarUDF {
3937 fn as_any(&self) -> &dyn Any {
3938 self
3939 }
3940 fn name(&self) -> &str {
3941 "TestScalarUDF"
3942 }
3943
3944 fn signature(&self) -> &Signature {
3945 &self.signature
3946 }
3947
3948 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3949 Ok(DataType::Int32)
3950 }
3951
3952 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3953 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3954 }
3955 }
3956
/// Filter over an aliased projection that mixes a group-by column with a
/// column computed from a volatile UDF.
///
/// The expected snapshot shows the `a > 5` conjunct reaching the table scan,
/// while the conjunct on the volatile-derived column `r` stays as a `Filter`
/// above the projection that computes it.
#[test]
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
    let table_scan = test_table_scan_with_name("test1")?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    let projected = vec![
        col("a"),
        sum(col("b")),
        add(volatile_call, lit(1)).alias("r"),
    ];
    let predicate = col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5)));

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(vec![col("a")], vec![sum(col("b"))])?
        .project(projected)?
        .alias("t")?
        .filter(predicate)?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.a > Int32(5) AND t.r > Float64(0.5)
        SubqueryAlias: t
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1
    ",
    );
    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.5)
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1, full_filters=[test1.a > Int32(5)]
    "
    )
}
3996
/// Filter on a volatile-derived column that sits above an inner join.
///
/// The expected snapshot shows the filter moving below the `SubqueryAlias`
/// but stopping above the projection that evaluates the volatile UDF; the
/// join and both scans are unchanged.
#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
    let left_scan = test_table_scan_with_name("test1")?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    let left_input = LogicalPlanBuilder::from(left_scan).build()?;
    let right_scan = test_table_scan_with_name("test2")?;
    let right_input = LogicalPlanBuilder::from(right_scan).build()?;

    // Equi-join keys: test1.a = test2.a, with no extra ON filter.
    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );

    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, None)?
        .project(vec![col("test1.a").alias("a"), volatile_call.alias("r")])?
        .alias("t")?
        .filter(col("t.r").gt(lit(0.8)))?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.r > Float64(0.8)
        SubqueryAlias: t
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    ",
    );
    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.8)
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    "
    )
}
4048
/// Filter made up solely of a volatile UDF comparison, above a projection.
///
/// The expected snapshot shows the filter moving below the projection but
/// remaining a separate `Filter` node instead of joining the scan's filters.
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    let scan = test_table_scan()?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));
    let predicate = volatile_call.gt(lit(0.1));

    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4078
/// Filter mixing a volatile UDF comparison with two plain column comparisons.
///
/// The expected snapshot shows the two column comparisons landing in the
/// scan's `full_filters`, while the volatile comparison stays behind as a
/// `Filter` node directly above the scan.
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    let scan = test_table_scan()?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    // Conjunct order matters for the snapshot of the unoptimized plan.
    let predicate = volatile_call
        .gt(lit(0.1))
        .and(col("t.a").gt(lit(5)))
        .and(col("t.b").gt(lit(10)));

    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
    "
    )
}
4112
/// Same mixed predicate as the supported-provider case, but over a provider
/// built with `TableProviderFilterPushDown::Unsupported`.
///
/// The expected snapshot shows nothing reaching the scan: all three conjuncts
/// stay in one `Filter` below the projection, with the volatile comparison
/// listed last.
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    let predicate = volatile_call
        .gt(lit(0.1))
        .and(col("t.a").gt(lit(5)))
        .and(col("t.b").gt(lit(10)));

    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?;
    let plan = scan_builder
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: a, b
        TableScan: test
    ",
    );
    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4149
/// Filter placed directly on top of an input-less user-defined extension node.
///
/// The expected snapshot is identical to the unoptimized plan: the filter
/// stays above `TestUserNode`.
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    /// Extension node with no inputs and no expressions; it only carries a
    /// schema.
    #[derive(Debug, Hash, Eq, PartialEq)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    // Every pair of nodes is reported as unordered.
    impl PartialOrd for TestUserNode {
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl TestUserNode {
        /// Builds a node whose schema is a single unqualified, non-nullable
        /// `Int64` column named `a`.
        fn new() -> Self {
            let field = Field::new("a", DataType::Int64, false);
            let schema =
                DFSchema::new_with_metadata(vec![(None, field.into())], Default::default())
                    .unwrap();

            Self {
                schema: Arc::new(schema),
            }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            f.write_str("TestUserNode")
        }

        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            // The node has neither expressions nor inputs, so both must be
            // empty on rebuild.
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            let schema = Arc::clone(&self.schema);
            Ok(Self { schema })
        }
    }

    let extension = Extension {
        node: Arc::new(TestUserNode::new()),
    };
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(extension))
        .filter(lit(false))?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: Boolean(false)
      TestUserNode
    ",
    );
    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: Boolean(false)
      TestUserNode
    "
    )
}
4235
/// Filter on a column that a projection computes with `leaf_udf_expr`.
///
/// The expected snapshot shows the filter staying above the projection that
/// defines `val`, leaving the scan untouched.
// NOTE(review): presumably `leaf_udf_expr` builds a UDF call that should be
// kept close to the leaves of the plan; confirm against its definition.
#[test]
fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
    let scan = test_table_scan()?;

    // `val` is derived from column `a`; `b` and `c` pass through unchanged.
    let projected = vec![leaf_udf_expr(col("a")).alias("val"), col("b"), col("c")];
    let projection = LogicalPlanBuilder::from(scan).project(projected)?.build()?;

    let predicate = col("val").gt(lit(150i64));
    let plan = LogicalPlanBuilder::from(projection)
        .filter(predicate)?
        .build()?;

    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: val > Int64(150)
      Projection: leaf_udf(test.a) AS val, test.b, test.c
        TableScan: test
    "
    )
}
4269
/// Filter combining a predicate on the `leaf_udf_expr`-derived column `val`
/// with a predicate on the pass-through column `b`.
///
/// The expected snapshot shows only the `b` predicate reaching the scan's
/// `full_filters`; the `val` predicate stays above the projection.
#[test]
fn filter_mixed_predicates_partial_push() -> Result<()> {
    let scan = test_table_scan()?;

    // `val` is derived from column `a`; `b` and `c` pass through unchanged.
    let projected = vec![leaf_udf_expr(col("a")).alias("val"), col("b"), col("c")];
    let projection = LogicalPlanBuilder::from(scan).project(projected)?.build()?;

    let on_val = col("val").gt(lit(150i64));
    let on_b = col("b").gt(lit(5i64));
    let plan = LogicalPlanBuilder::from(projection)
        .filter(on_val.and(on_b))?
        .build()?;

    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: val > Int64(150)
      Projection: leaf_udf(test.a) AS val, test.b, test.c
        TableScan: test, full_filters=[test.b > Int64(5)]
    "
    )
}
4299}