1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31 internal_err, plan_err, qualified_name, Column, DFSchema, Result,
32};
33use datafusion_expr::expr::WindowFunction;
34use datafusion_expr::expr_rewriter::replace_col;
35use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
36use datafusion_expr::utils::{
37 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
38};
39use datafusion_expr::{
40 and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
41};
42
43use crate::optimizer::ApplyOrder;
44use crate::simplify_expressions::simplify_predicates;
45use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
46use crate::{OptimizerConfig, OptimizerRule};
47
/// Optimizer rule registered under the name `push_down_filter`.
///
/// Its [`OptimizerRule`] implementation rewrites `Filter` and `Join` plan
/// nodes so that filter predicates are evaluated as close to the inputs as
/// the child node allows.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
137
138pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
163 match join_type {
164 JoinType::Inner => (true, true),
165 JoinType::Left => (true, false),
166 JoinType::Right => (false, true),
167 JoinType::Full => (false, false),
168 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
171 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
174 }
175}
176
177pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
187 match join_type {
188 JoinType::Inner => (true, true),
189 JoinType::Left => (false, true),
190 JoinType::Right => (true, false),
191 JoinType::Full => (false, false),
192 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
193 JoinType::LeftAnti => (false, true),
194 JoinType::RightAnti => (true, false),
195 JoinType::LeftMark => (false, true),
196 JoinType::RightMark => (true, false),
197 }
198}
199
200#[derive(Debug)]
203struct ColumnChecker<'a> {
204 left_schema: &'a DFSchema,
206 left_columns: Option<HashSet<Column>>,
208 right_schema: &'a DFSchema,
210 right_columns: Option<HashSet<Column>>,
212}
213
214impl<'a> ColumnChecker<'a> {
215 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
216 Self {
217 left_schema,
218 left_columns: None,
219 right_schema,
220 right_columns: None,
221 }
222 }
223
224 fn is_left_only(&mut self, predicate: &Expr) -> bool {
226 if self.left_columns.is_none() {
227 self.left_columns = Some(schema_columns(self.left_schema));
228 }
229 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
230 }
231
232 fn is_right_only(&mut self, predicate: &Expr) -> bool {
234 if self.right_columns.is_none() {
235 self.right_columns = Some(schema_columns(self.right_schema));
236 }
237 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
238 }
239}
240
241fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
243 schema
244 .iter()
245 .flat_map(|(qualifier, field)| {
246 [
247 Column::new(qualifier.cloned(), field.name()),
248 Column::new_unqualified(field.name()),
250 ]
251 })
252 .collect::<HashSet<_>>()
253}
254
255fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
257 let mut is_evaluate = true;
258 predicate.apply(|expr| match expr {
259 Expr::Column(_)
260 | Expr::Literal(_, _)
261 | Expr::Placeholder(_)
262 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
263 Expr::Exists { .. }
264 | Expr::InSubquery(_)
265 | Expr::ScalarSubquery(_)
266 | Expr::OuterReferenceColumn(_, _)
267 | Expr::Unnest(_) => {
268 is_evaluate = false;
269 Ok(TreeNodeRecursion::Stop)
270 }
271 Expr::Alias(_)
272 | Expr::BinaryExpr(_)
273 | Expr::Like(_)
274 | Expr::SimilarTo(_)
275 | Expr::Not(_)
276 | Expr::IsNotNull(_)
277 | Expr::IsNull(_)
278 | Expr::IsTrue(_)
279 | Expr::IsFalse(_)
280 | Expr::IsUnknown(_)
281 | Expr::IsNotTrue(_)
282 | Expr::IsNotFalse(_)
283 | Expr::IsNotUnknown(_)
284 | Expr::Negative(_)
285 | Expr::Between(_)
286 | Expr::Case(_)
287 | Expr::Cast(_)
288 | Expr::TryCast(_)
289 | Expr::InList { .. }
290 | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
291 #[expect(deprecated)]
293 Expr::AggregateFunction(_)
294 | Expr::WindowFunction(_)
295 | Expr::Wildcard { .. }
296 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
297 })?;
298 Ok(is_evaluate)
299}
300
301fn extract_or_clauses_for_join<'a>(
335 filters: &'a [Expr],
336 schema: &'a DFSchema,
337) -> impl Iterator<Item = Expr> + 'a {
338 let schema_columns = schema_columns(schema);
339
340 filters.iter().filter_map(move |expr| {
342 if let Expr::BinaryExpr(BinaryExpr {
343 left,
344 op: Operator::Or,
345 right,
346 }) = expr
347 {
348 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
349 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
350
351 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
353 return Some(or(left_expr, right_expr));
354 }
355 }
356 None
357 })
358}
359
360fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
372 let mut predicate = None;
373
374 match expr {
375 Expr::BinaryExpr(BinaryExpr {
376 left: l_expr,
377 op: Operator::Or,
378 right: r_expr,
379 }) => {
380 let l_expr = extract_or_clause(l_expr, schema_columns);
381 let r_expr = extract_or_clause(r_expr, schema_columns);
382
383 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
384 predicate = Some(or(l_expr, r_expr));
385 }
386 }
387 Expr::BinaryExpr(BinaryExpr {
388 left: l_expr,
389 op: Operator::And,
390 right: r_expr,
391 }) => {
392 let l_expr = extract_or_clause(l_expr, schema_columns);
393 let r_expr = extract_or_clause(r_expr, schema_columns);
394
395 match (l_expr, r_expr) {
396 (Some(l_expr), Some(r_expr)) => {
397 predicate = Some(and(l_expr, r_expr));
398 }
399 (Some(l_expr), None) => {
400 predicate = Some(l_expr);
401 }
402 (None, Some(r_expr)) => {
403 predicate = Some(r_expr);
404 }
405 (None, None) => {
406 predicate = None;
407 }
408 }
409 }
410 _ => {
411 if has_all_column_refs(expr, schema_columns) {
412 predicate = Some(expr.clone());
413 }
414 }
415 }
416
417 predicate
418}
419
420fn push_down_all_join(
422 predicates: Vec<Expr>,
423 inferred_join_predicates: Vec<Expr>,
424 mut join: Join,
425 on_filter: Vec<Expr>,
426) -> Result<Transformed<LogicalPlan>> {
427 let is_inner_join = join.join_type == JoinType::Inner;
428 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
430
431 let left_schema = join.left.schema();
436 let right_schema = join.right.schema();
437 let mut left_push = vec![];
438 let mut right_push = vec![];
439 let mut keep_predicates = vec![];
440 let mut join_conditions = vec![];
441 let mut checker = ColumnChecker::new(left_schema, right_schema);
442 for predicate in predicates {
443 if left_preserved && checker.is_left_only(&predicate) {
444 left_push.push(predicate);
445 } else if right_preserved && checker.is_right_only(&predicate) {
446 right_push.push(predicate);
447 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
448 join_conditions.push(predicate);
451 } else {
452 keep_predicates.push(predicate);
453 }
454 }
455
456 for predicate in inferred_join_predicates {
458 if left_preserved && checker.is_left_only(&predicate) {
459 left_push.push(predicate);
460 } else if right_preserved && checker.is_right_only(&predicate) {
461 right_push.push(predicate);
462 }
463 }
464
465 let mut on_filter_join_conditions = vec![];
466 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
467
468 if !on_filter.is_empty() {
469 for on in on_filter {
470 if on_left_preserved && checker.is_left_only(&on) {
471 left_push.push(on)
472 } else if on_right_preserved && checker.is_right_only(&on) {
473 right_push.push(on)
474 } else {
475 on_filter_join_conditions.push(on)
476 }
477 }
478 }
479
480 if left_preserved {
483 left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
484 left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
485 }
486 if right_preserved {
487 right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
488 right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
489 }
490
491 if on_left_preserved {
494 left_push.extend(extract_or_clauses_for_join(
495 &on_filter_join_conditions,
496 left_schema,
497 ));
498 }
499 if on_right_preserved {
500 right_push.extend(extract_or_clauses_for_join(
501 &on_filter_join_conditions,
502 right_schema,
503 ));
504 }
505
506 if let Some(predicate) = conjunction(left_push) {
507 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
508 }
509 if let Some(predicate) = conjunction(right_push) {
510 join.right =
511 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
512 }
513
514 join_conditions.extend(on_filter_join_conditions);
516 join.filter = conjunction(join_conditions);
517
518 let plan = LogicalPlan::Join(join);
520 let plan = if let Some(predicate) = conjunction(keep_predicates) {
521 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
522 } else {
523 plan
524 };
525 Ok(Transformed::yes(plan))
526}
527
528fn push_down_join(
529 join: Join,
530 parent_predicate: Option<&Expr>,
531) -> Result<Transformed<LogicalPlan>> {
532 let predicates = parent_predicate
534 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
535
536 let on_filters = join
538 .filter
539 .as_ref()
540 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
541
542 let inferred_join_predicates =
544 infer_join_predicates(&join, &predicates, &on_filters)?;
545
546 if on_filters.is_empty()
547 && predicates.is_empty()
548 && inferred_join_predicates.is_empty()
549 {
550 return Ok(Transformed::no(LogicalPlan::Join(join)));
551 }
552
553 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
554}
555
556fn infer_join_predicates(
567 join: &Join,
568 predicates: &[Expr],
569 on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571 let join_col_keys = join
573 .on
574 .iter()
575 .filter_map(|(l, r)| {
576 let left_col = l.try_as_col()?;
577 let right_col = r.try_as_col()?;
578 Some((left_col, right_col))
579 })
580 .collect::<Vec<_>>();
581
582 let join_type = join.join_type;
583
584 let mut inferred_predicates = InferredPredicates::new(join_type);
585
586 infer_join_predicates_from_predicates(
587 &join_col_keys,
588 predicates,
589 &mut inferred_predicates,
590 )?;
591
592 infer_join_predicates_from_on_filters(
593 &join_col_keys,
594 join_type,
595 on_filters,
596 &mut inferred_predicates,
597 )?;
598
599 Ok(inferred_predicates.predicates)
600}
601
602struct InferredPredicates {
611 predicates: Vec<Expr>,
612 is_inner_join: bool,
613}
614
615impl InferredPredicates {
616 fn new(join_type: JoinType) -> Self {
617 Self {
618 predicates: vec![],
619 is_inner_join: matches!(join_type, JoinType::Inner),
620 }
621 }
622
623 fn try_build_predicate(
624 &mut self,
625 predicate: Expr,
626 replace_map: &HashMap<&Column, &Column>,
627 ) -> Result<()> {
628 if self.is_inner_join
629 || matches!(
630 is_restrict_null_predicate(
631 predicate.clone(),
632 replace_map.keys().cloned()
633 ),
634 Ok(true)
635 )
636 {
637 self.predicates.push(replace_col(predicate, replace_map)?);
638 }
639
640 Ok(())
641 }
642}
643
644fn infer_join_predicates_from_predicates(
654 join_col_keys: &[(&Column, &Column)],
655 predicates: &[Expr],
656 inferred_predicates: &mut InferredPredicates,
657) -> Result<()> {
658 infer_join_predicates_impl::<true, true>(
659 join_col_keys,
660 predicates,
661 inferred_predicates,
662 )
663}
664
665fn infer_join_predicates_from_on_filters(
678 join_col_keys: &[(&Column, &Column)],
679 join_type: JoinType,
680 on_filters: &[Expr],
681 inferred_predicates: &mut InferredPredicates,
682) -> Result<()> {
683 match join_type {
684 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
685 JoinType::Inner => infer_join_predicates_impl::<true, true>(
686 join_col_keys,
687 on_filters,
688 inferred_predicates,
689 ),
690 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
691 infer_join_predicates_impl::<true, false>(
692 join_col_keys,
693 on_filters,
694 inferred_predicates,
695 )
696 }
697 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
698 infer_join_predicates_impl::<false, true>(
699 join_col_keys,
700 on_filters,
701 inferred_predicates,
702 )
703 }
704 }
705}
706
707fn infer_join_predicates_impl<
724 const ENABLE_LEFT_TO_RIGHT: bool,
725 const ENABLE_RIGHT_TO_LEFT: bool,
726>(
727 join_col_keys: &[(&Column, &Column)],
728 input_predicates: &[Expr],
729 inferred_predicates: &mut InferredPredicates,
730) -> Result<()> {
731 for predicate in input_predicates {
732 let mut join_cols_to_replace = HashMap::new();
733
734 for &col in &predicate.column_refs() {
735 for (l, r) in join_col_keys.iter() {
736 if ENABLE_LEFT_TO_RIGHT && col == *l {
737 join_cols_to_replace.insert(col, *r);
738 break;
739 }
740 if ENABLE_RIGHT_TO_LEFT && col == *r {
741 join_cols_to_replace.insert(col, *l);
742 break;
743 }
744 }
745 }
746 if join_cols_to_replace.is_empty() {
747 continue;
748 }
749
750 inferred_predicates
751 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
752 }
753 Ok(())
754}
755
impl OptimizerRule for PushDownFilter {
    /// Name of this rule.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// Requests top-down application of the rule.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements [`OptimizerRule::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a single plan node.
    ///
    /// * A bare `Join` has its own filter conjuncts pushed into its inputs.
    /// * A `Filter` is simplified and then, depending on the kind of its
    ///   input, pushed below that input fully, partially, or not at all.
    /// * Any other node is returned unchanged.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // A join without a filter above it: only its ON filter can be pushed.
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        // Captured before `plan` is destructured; used by the `Union` arm.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        // Simplify the conjuncts; the filter predicate is only replaced when
        // simplification changed the number of conjuncts.
        let predicate = split_conjunction_owned(filter.predicate.clone());
        let old_predicate_len = predicate.len();
        let new_predicates = simplify_predicates(predicate)?;
        if old_predicate_len != new_predicates.len() {
            let Some(new_predicate) = conjunction(new_predicates) else {
                // No predicates remain: drop the filter and return its input.
                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
            };
            filter.predicate = new_predicate;
        }

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over Filter: merge both into one filter and rewrite again.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                // `IndexSet` removes duplicate predicates while keeping their order.
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                #[allow(clippy::used_underscore_binding)]
                self.rewrite(new_filter, _config)
            }
            // Repartition: the whole filter moves below the node.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            // Distinct: the whole filter moves below the node.
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            // Sort: the whole filter moves below the node.
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            // SubqueryAlias: rename columns from the alias schema to the input
            // schema (matched by position), then move the filter below.
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            // Projection: see `rewrite_projection` for what can be pushed.
            LogicalPlan::Projection(projection) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        // Predicates that could not be pushed stay above the projection.
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: restore the original filter.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            // Unnest: predicates on unnested output columns stay above the
            // node; the remaining predicates move below it.
            LogicalPlan::Unnest(mut unnest) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                let mut unnest_struct_columns = vec![];

                // Collect `<field>.<child>` column names for every unnested struct column.
                for idx in &unnest.struct_type_columns {
                    let (sub_qualifier, field) =
                        unnest.input.schema().qualified_field(*idx);
                    let field_name = field.name().clone();

                    if let DataType::Struct(children) = field.data_type() {
                        for child in children {
                            let child_name = child.name().clone();
                            unnest_struct_columns.push(Column::new(
                                sub_qualifier.cloned(),
                                format!("{field_name}.{child_name}"),
                            ));
                        }
                    }
                }

                for predicate in predicates {
                    // All columns referenced by this predicate.
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    let contains_list_columns =
                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                            accum.contains(&unnest_list.output_column)
                        });
                    let contains_struct_columns =
                        unnest_struct_columns.iter().any(|c| accum.contains(c));

                    if contains_list_columns || contains_struct_columns {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Nothing can be pushed: restore the original filter.
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot be `None`: `non_unnest_predicates` is non-empty here.
                    conjunction(non_unnest_predicates).unwrap(),
                    unnest_input,
                )?);

                // The new filter becomes the input of the unnest node.
                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            // Union: a copy of the filter, with columns renamed by position
            // to each input's schema, is placed on top of every input.
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            // Aggregate: predicates that only reference group-by expressions
            // move below the aggregate; the rest stay above it.
            LogicalPlan::Aggregate(agg) => {
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
                    .collect::<Result<HashSet<_>>>()?;

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Pushed predicates refer to group expressions by output name;
                // replace those references with the group expressions themselves.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // Insert the pushed predicates below the aggregate, if any.
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Re-wrap with the predicates that must stay above, if any.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Window: predicates that only reference columns used as a
            // partition key by *every* window expression move below the node.
            LogicalPlan::Window(window) => {
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    // NOTE(review): assumes an aliased window expression
                                    // always wraps a window function; confirm.
                                    unreachable!()
                                }
                            }
                            _ => {
                                // NOTE(review): assumes window expressions are only
                                // window functions or aliases of them; confirm.
                                unreachable!()
                            }
                        }
                    })
                    // Intersection of the partition keys of all window expressions.
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // Insert the pushed predicates below the window, if any.
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Re-wrap with the predicates that must stay above, if any.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Join: distribute the filter's conjuncts around the join.
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            // TableScan: offer non-volatile predicates to the table source.
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile predicates are never offered to the source.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // One pushdown verdict is expected per non-volatile predicate.
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                if non_volatile_filters.len() != supported_filters.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        supported_filters.len(),
                        non_volatile_filters.len());
                }

                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // Everything not reported as `Unsupported` is added to the scan.
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Existing scan filters first, then the new ones, de-duplicated.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Everything not reported as `Exact`, plus the volatile
                // predicates, must still be evaluated by a `Filter` node.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            // Extension node: push predicates unless they reference a column
            // the node asks to be excluded from pushdown.
            LogicalPlan::Extension(extension_plan) => {
                // No inputs to push into: restore the original filter.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` means push, `false` means keep.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            Ok(false) // keep above the extension node
                        } else {
                            Ok(true) // push below the extension node
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Every predicate is kept: restore the original filter.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Split the conjuncts according to the flags computed above.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                // The same pushed predicate is applied to every input of the node.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the extension node on top of the new children.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input: the filter stays where it is.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1261
1262fn rewrite_projection(
1290 predicates: Vec<Expr>,
1291 mut projection: Projection,
1292) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1293 let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1297 .schema
1298 .iter()
1299 .zip(projection.expr.iter())
1300 .map(|((qualifier, field), expr)| {
1301 let expr = expr.clone().unalias();
1303
1304 (qualified_name(qualifier, field.name()), expr)
1305 })
1306 .partition(|(_, value)| value.is_volatile());
1307
1308 let mut push_predicates = vec![];
1309 let mut keep_predicates = vec![];
1310 for expr in predicates {
1311 if contain(&expr, &volatile_map) {
1312 keep_predicates.push(expr);
1313 } else {
1314 push_predicates.push(expr);
1315 }
1316 }
1317
1318 match conjunction(push_predicates) {
1319 Some(expr) => {
1320 let new_filter = LogicalPlan::Filter(Filter::try_new(
1323 replace_cols_by_name(expr, &non_volatile_map)?,
1324 std::mem::take(&mut projection.input),
1325 )?);
1326
1327 projection.input = Arc::new(new_filter);
1328
1329 Ok((
1330 Transformed::yes(LogicalPlan::Projection(projection)),
1331 conjunction(keep_predicates),
1332 ))
1333 }
1334 None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1335 }
1336}
1337
1338pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1340 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1341}
1342
1343fn insert_below(
1357 plan: LogicalPlan,
1358 new_child: LogicalPlan,
1359) -> Result<Transformed<LogicalPlan>> {
1360 let mut new_child = Some(new_child);
1361 let transformed_plan = plan.map_children(|_child| {
1362 if let Some(new_child) = new_child.take() {
1363 Ok(Transformed::yes(new_child))
1364 } else {
1365 internal_err!("node had more than one input")
1367 }
1368 })?;
1369
1370 if new_child.is_some() {
1372 return internal_err!("node had no inputs");
1373 }
1374
1375 Ok(transformed_plan)
1376}
1377
1378impl PushDownFilter {
1379 #[allow(missing_docs)]
1380 pub fn new() -> Self {
1381 Self {}
1382 }
1383}
1384
1385pub fn replace_cols_by_name(
1387 e: Expr,
1388 replace_map: &HashMap<String, Expr>,
1389) -> Result<Expr> {
1390 e.transform_up(|expr| {
1391 Ok(if let Expr::Column(c) = &expr {
1392 match replace_map.get(&c.flat_name()) {
1393 Some(new_c) => Transformed::yes(new_c.clone()),
1394 None => Transformed::no(expr),
1395 }
1396 } else {
1397 Transformed::no(expr)
1398 })
1399 })
1400 .data()
1401}
1402
1403fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1405 let mut is_contain = false;
1406 e.apply(|expr| {
1407 Ok(if let Expr::Column(c) = &expr {
1408 match check_map.get(&c.flat_name()) {
1409 Some(_) => {
1410 is_contain = true;
1411 TreeNodeRecursion::Stop
1412 }
1413 None => TreeNodeRecursion::Continue,
1414 }
1415 } else {
1416 TreeNodeRecursion::Continue
1417 })
1418 })
1419 .unwrap();
1420 is_contain
1421}
1422
1423#[cfg(test)]
1424mod tests {
1425 use std::any::Any;
1426 use std::cmp::Ordering;
1427 use std::fmt::{Debug, Formatter};
1428
1429 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1430 use async_trait::async_trait;
1431
1432 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1433 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1434 use datafusion_expr::logical_plan::table_scan;
1435 use datafusion_expr::{
1436 col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1437 LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1438 TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1439 WindowFunctionDefinition,
1440 };
1441
1442 use crate::assert_optimized_plan_eq_snapshot;
1443 use crate::optimizer::Optimizer;
1444 use crate::simplify_expressions::SimplifyExpressions;
1445 use crate::test::*;
1446 use crate::OptimizerContext;
1447 use datafusion_expr::test::function_stub::sum;
1448 use insta::assert_snapshot;
1449
1450 use super::*;
1451
    // Observer callback passed to `Optimizer::optimize`; intentionally a no-op.
    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1453
    // Optimizes `$plan` with `PushDownFilter` as the only rule, using a
    // context limited to one pass, and checks the result against the inline
    // snapshot `$expected` via `assert_optimized_plan_eq_snapshot!`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }
1469
    // Optimizes `$plan` with `SimplifyExpressions` followed by
    // `PushDownFilter` under a default `OptimizerContext`, then checks the
    // result against the inline snapshot `$expected`.
    // Uses `?`, so it must be invoked from a function returning `Result`.
    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer = Optimizer::with_rules(vec![
                Arc::new(SimplifyExpressions::new()),
                Arc::new(PushDownFilter::new()),
            ]);
            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
            assert_snapshot!(optimized_plan, @ $expected);
            Ok::<(), DataFusionError>(())
        }};
    }
1484
    // A filter above a projection of plain columns ends up in the table scan.
    #[test]
    fn filter_before_projection() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .filter(col("a").eq(lit(1i64)))?
            .build()?;
        assert_optimized_plan_equal!(
            // expected: the filter is applied below the projection
            plan,
            @r"
        Projection: test.a, test.b
          TableScan: test, full_filters=[test.a = Int64(1)]
        "
        )
    }
1501
/// A filter above a limit must stay where it is: the expected plan keeps the
/// filter on top of the `Limit` node and leaves the scan unfiltered.
#[test]
fn filter_after_limit() -> Result<()> {
    let limited_then_filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(10))?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        limited_then_filtered,
        @r"
    Filter: test.a = Int64(1)
      Limit: skip=0, fetch=10
        Projection: test.a, test.b
          TableScan: test
    "
    )
}
1521
/// A predicate without any column reference (`0 = 1`) is still pushed into
/// the table scan.
#[test]
fn filter_no_columns() -> Result<()> {
    let constant_predicate = lit(0i64).eq(lit(1i64));
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(constant_predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        filtered,
        @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
    )
}
1533
/// A filter above two stacked projections travels through both of them and
/// lands in the table scan.
#[test]
fn filter_jump_2_plans() -> Result<()> {
    let stacked = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .project(vec![col("c"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        stacked,
        @r"
    Projection: test.c, test.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1552
/// A filter on a plain group-by column is pushed below the aggregate and into
/// the scan.
#[test]
fn filter_move_agg() -> Result<()> {
    let total_salary = sum(col("b")).alias("total_salary");
    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], vec![total_salary])?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        aggregated,
        @r"
    Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1569
/// When the aggregate groups by the expression `b + a`, a filter on the bare
/// column `b` is expected to remain above the aggregate.
#[test]
fn filter_complex_group_by() -> Result<()> {
    let group_by = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        aggregated,
        @r"
    Filter: test.b > Int64(10)
      Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
        TableScan: test
    "
    )
}
1586
/// A filter that references the group-by expression through its output column
/// name (`test.b + test.a`) is pushed below the aggregate, with the column
/// replaced by the underlying expression.
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
    let group_by = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .filter(col("test.b + test.a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        aggregated,
        @r"
    Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
      TableScan: test, full_filters=[test.b + test.a > Int64(10)]
    "
    )
}
1601
/// A filter on an aggregate output (`sum(b) AS b`) cannot be pushed and stays
/// above the aggregate.
#[test]
fn filter_keep_agg() -> Result<()> {
    let summed_b = sum(col("b")).alias("b");
    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], vec![summed_b])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        aggregated,
        @r"
    Filter: b > Int64(10)
      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
        TableScan: test
    "
    )
}
1619
/// With a window partitioned by `a` and `b`, a filter on `b` is pushed below
/// the window into the scan.
#[test]
fn filter_move_window() -> Result<()> {
    let rank_def = WindowFunctionDefinition::WindowUDF(
        datafusion_functions_window::rank::rank_udwf(),
    );
    let ranked = Expr::from(WindowFunction::new(rank_def, vec![]))
        .partition_by(vec![col("a"), col("b")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![ranked])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.b > Int64(10)]
    "
    )
}
1649
/// With a window partitioned by `a` and `b`, a conjunction over `a` and `b`
/// is pushed below the window as two separate scan filters.
#[test]
fn filter_move_complex_window() -> Result<()> {
    let rank_def = WindowFunctionDefinition::WindowUDF(
        datafusion_functions_window::rank::rank_udwf(),
    );
    let ranked = Expr::from(WindowFunction::new(rank_def, vec![]))
        .partition_by(vec![col("a"), col("b")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64)));
    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![ranked])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
    "
    )
}
1680
/// With a window partitioned only by `a`, the `a` half of a conjunction is
/// pushed into the scan while the `b` half stays above the window.
#[test]
fn filter_move_partial_window() -> Result<()> {
    let rank_def = WindowFunctionDefinition::WindowUDF(
        datafusion_functions_window::rank::rank_udwf(),
    );
    let ranked = Expr::from(WindowFunction::new(rank_def, vec![]))
        .partition_by(vec![col("a")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64)));
    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![ranked])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    Filter: test.b = Int64(1)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1711
/// When the window partitions by the expression `a + b`, a filter on that
/// same expression is expected to stay above the window.
#[test]
fn filter_expression_keep_window() -> Result<()> {
    let rank_def = WindowFunctionDefinition::WindowUDF(
        datafusion_functions_window::rank::rank_udwf(),
    );
    let ranked = Expr::from(WindowFunction::new(rank_def, vec![]))
        .partition_by(vec![add(col("a"), col("b"))])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = add(col("a"), col("b")).gt(lit(10i64));
    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![ranked])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    Filter: test.a + test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1745
/// A filter on the window's ORDER BY column (`c`), which is not a partition
/// column, stays above the window.
#[test]
fn filter_order_keep_window() -> Result<()> {
    let rank_def = WindowFunctionDefinition::WindowUDF(
        datafusion_functions_window::rank::rank_udwf(),
    );
    let ranked = Expr::from(WindowFunction::new(rank_def, vec![]))
        .partition_by(vec![col("a")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![ranked])?
        .filter(col("c").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    Filter: test.c > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1776
/// Two window functions, partitioned by `[a]` and `[b, a]`: a filter on `a`,
/// which both partitionings share, is pushed below the window node.
#[test]
fn filter_multiple_windows_common_partitions() -> Result<()> {
    // Builds `rank()` over the given partition columns, ordered by `c`.
    let rank_over = |partition: Vec<Expr>| {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(partition)
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap()
    };
    let by_a = rank_over(vec![col("a")]);
    let by_b_a = rank_over(vec![col("b"), col("a")]);

    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![by_a, by_b_a])?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1818
/// Two window functions, partitioned by `[a]` and `[b, a]`: a filter on `b`,
/// which only one of the partitionings contains, stays above the window node.
#[test]
fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
    // Builds `rank()` over the given partition columns, ordered by `c`.
    let rank_over = |partition: Vec<Expr>| {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(partition)
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap()
    };
    let by_a = rank_over(vec![col("a")]);
    let by_b_a = rank_over(vec![col("b"), col("a")]);

    let windowed = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![by_a, by_b_a])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        windowed,
        @r"
    Filter: test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1861
/// A filter on an aliased projection column (`a AS b`) is pushed down with the
/// alias resolved back to the underlying column.
#[test]
fn alias() -> Result<()> {
    let projected = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        projected,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1879
1880 fn add(left: Expr, right: Expr) -> Expr {
1881 Expr::BinaryExpr(BinaryExpr::new(
1882 Box::new(left),
1883 Operator::Plus,
1884 Box::new(right),
1885 ))
1886 }
1887
1888 fn multiply(left: Expr, right: Expr) -> Expr {
1889 Expr::BinaryExpr(BinaryExpr::new(
1890 Box::new(left),
1891 Operator::Multiply,
1892 Box::new(right),
1893 ))
1894 }
1895
/// A filter on an aliased arithmetic expression is pushed down with the alias
/// expanded to the full expression.
#[test]
fn complex_expression() -> Result<()> {
    let b_expr = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let projected = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![b_expr, col("c")])?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(projected,
    @r"
    Filter: b = Int64(1)
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test
    ",
    );

    // After push-down the filter is expressed in terms of the scan's columns.
    assert_optimized_plan_equal!(
        projected,
        @r"
    Projection: test.a * Int32(2) + test.c AS b, test.c
      TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
    "
    )
}
1925
/// A filter above two projections that each introduce an aliased expression
/// is pushed to the scan with both aliases expanded.
#[test]
fn complex_plan() -> Result<()> {
    let inner_b = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let outer_a = multiply(col("b"), lit(3)).alias("a");
    let projected = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![inner_b, col("c")])?
        .project(vec![outer_a, col("c")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(projected,
    @r"
    Filter: a = Int64(1)
      Projection: b * Int32(3) AS a, test.c
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test
    ",
    );

    // After push-down the filter is rewritten in terms of the scan's columns.
    assert_optimized_plan_equal!(
        projected,
        @r"
    Projection: b * Int32(3) AS a, test.c
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
    "
    )
}
1959
/// Minimal user-defined logical node used by the tests in this module to
/// check filter push-down through `LogicalPlan::Extension`.
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoopPlan {
    /// Child plans, returned from `inputs()`.
    input: Vec<LogicalPlan>,
    /// Schema returned from `schema()`.
    schema: DFSchemaRef,
}
1965
1966 impl PartialOrd for NoopPlan {
1968 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1969 self.input.partial_cmp(&other.input)
1970 }
1971 }
1972
1973 impl UserDefinedLogicalNodeCore for NoopPlan {
1974 fn name(&self) -> &str {
1975 "NoopPlan"
1976 }
1977
1978 fn inputs(&self) -> Vec<&LogicalPlan> {
1979 self.input.iter().collect()
1980 }
1981
1982 fn schema(&self) -> &DFSchemaRef {
1983 &self.schema
1984 }
1985
1986 fn expressions(&self) -> Vec<Expr> {
1987 self.input
1988 .iter()
1989 .flat_map(|child| child.expressions())
1990 .collect()
1991 }
1992
1993 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
1994 HashSet::from_iter(vec!["c".to_string()])
1995 }
1996
1997 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1998 write!(f, "NoopPlan")
1999 }
2000
2001 fn with_exprs_and_inputs(
2002 &self,
2003 _exprs: Vec<Expr>,
2004 inputs: Vec<LogicalPlan>,
2005 ) -> Result<Self> {
2006 Ok(Self {
2007 input: inputs,
2008 schema: Arc::clone(&self.schema),
2009 })
2010 }
2011
2012 fn supports_limit_pushdown(&self) -> bool {
2013 false }
2015 }
2016
/// Push-down through a `NoopPlan` extension node, with one and with two
/// inputs: a predicate on `a` reaches every input scan, while a predicate on
/// `c` (listed in `prevent_predicate_push_down_columns`) stays above the node.
#[test]
fn user_defined_plan() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Wraps `inputs` in a `NoopPlan` extension that reuses the scan's schema.
    let noop_over = |inputs: Vec<LogicalPlan>| {
        LogicalPlan::Extension(Extension {
            node: Arc::new(NoopPlan {
                input: inputs,
                schema: Arc::clone(table_scan.schema()),
            }),
        })
    };
    let a_eq_1 = || col("a").eq(lit(1i64));
    let a_eq_1_and_c_eq_2 = || col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64)));

    // Single input, pushable predicate.
    let plan = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(a_eq_1())?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Single input, predicate partly on the blocked column.
    let plan = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(a_eq_1_and_c_eq_2())?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, pushable predicate.
    let plan =
        LogicalPlanBuilder::from(noop_over(vec![table_scan.clone(), table_scan.clone()]))
            .filter(a_eq_1())?
            .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate partly on the blocked column.
    let plan =
        LogicalPlanBuilder::from(noop_over(vec![table_scan.clone(), table_scan.clone()]))
            .filter(a_eq_1_and_c_eq_2())?
            .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2101
/// Two stacked filters above an aggregate: the one on the group-by column is
/// pushed to the scan, the one on the aggregate output stays above the
/// aggregate.
#[test]
fn multi_filter() -> Result<()> {
    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(col("b").gt(lit(10i64)))?
        .filter(col("sum(test.c)").gt(lit(10i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(aggregated,
    @r"
    Filter: sum(test.c) > Int64(10)
      Filter: b > Int64(10)
        Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
          Projection: test.a AS b, test.c
            TableScan: test
    ",
    );

    assert_optimized_plan_equal!(
        aggregated,
        @r"
    Filter: sum(test.c) > Int64(10)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2136
/// A single conjunctive filter is split: the conjunct on the group-by column
/// is pushed to the scan, the conjuncts on the aggregate output remain above
/// the aggregate.
#[test]
fn split_filter() -> Result<()> {
    let sum_gt_10 = col("sum(test.c)").gt(lit(10i64));
    let b_gt_10 = col("b").gt(lit(10i64));
    let sum_lt_20 = col("sum(test.c)").lt(lit(20i64));

    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(and(sum_gt_10, and(b_gt_10, sum_lt_20)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(aggregated,
    @r"
    Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );

    assert_optimized_plan_equal!(
        aggregated,
        @r"
    Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2172
/// A filter above a projection and two stacked limits moves below the
/// projection but stops on top of the outer limit.
#[test]
fn double_limit() -> Result<()> {
    let limited_twice = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(20))?
        .limit(0, Some(10))?
        .project(vec![col("a"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        limited_twice,
        @r"
    Projection: test.a, test.b
      Filter: test.a = Int64(1)
        Limit: skip=0, fetch=10
          Limit: skip=0, fetch=20
            Projection: test.a, test.b
              TableScan: test
    "
    )
}
2197
/// A filter above a union is pushed into both union inputs, each with the
/// column qualified by that input's table name.
#[test]
fn union_all() -> Result<()> {
    let second_input =
        LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;
    let unioned = LogicalPlanBuilder::from(test_table_scan()?)
        .union(second_input)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        unioned,
        @r"
    Union
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test2, full_filters=[test2.a = Int64(1)]
    "
    )
}
2216
/// A filter above a union of two aliased projections is pushed through the
/// subquery alias and the projection of each branch down to the scans.
#[test]
fn union_all_on_projection() -> Result<()> {
    let branch = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b")])?
        .alias("test2")?;
    let other_branch = branch.clone().build()?;

    let unioned = branch
        .union(other_branch)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        unioned,
        @r"
    Union
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2244
/// Despite the name, this builds a cross join (not a union) of two inputs with
/// different column names. Each conjunct of the filter is pushed to the scan
/// of the side whose column it references.
#[test]
fn test_union_different_schema() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let right_schema = Schema::new(vec![
        Field::new("d", DataType::UInt32, false),
        Field::new("e", DataType::UInt32, false),
        Field::new("f", DataType::UInt32, false),
    ]);
    let right_input = table_scan(Some("test1"), &right_schema, None)?
        .project(vec![col("d"), col("e"), col("f")])?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
    let joined = LogicalPlanBuilder::from(left_input)
        .cross_join(right_input)?
        .project(vec![col("test.a"), col("test1.d")])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        joined,
        @r"
    Projection: test.a, test1.d
      Cross Join: 
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.d, test1.e, test1.f
          TableScan: test1, full_filters=[test1.d > Int32(2)]
    "
    )
}
2278
/// Cross join of two tables that share column names: each conjunct is routed
/// to the correct side by its qualifier (`test.a` vs `test1.a`).
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
    let joined = LogicalPlanBuilder::from(left_input)
        .cross_join(right_input)?
        .project(vec![col("test.a"), col("test1.a")])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        joined,
        @r"
    Projection: test.a, test1.a
      Cross Join: 
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.a, test1.b, test1.c
          TableScan: test1, full_filters=[test1.a > Int32(2)]
    "
    )
}
2308
/// Two filters separated by a limit: the lower one reaches the scan, the upper
/// one moves below the top projection but may not cross the limit.
#[test]
fn filter_2_breaks_limits() -> Result<()> {
    let layered = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?
        .limit(0, Some(1))?
        .project(vec![col("a")])?
        .filter(col("a").gt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(layered,
    @r"
    Filter: test.a >= Int64(1)
      Projection: test.a
        Limit: skip=0, fetch=1
          Filter: test.a <= Int64(1)
            Projection: test.a
              TableScan: test
    ",
    );

    assert_optimized_plan_equal!(
        layered,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Limit: skip=0, fetch=1
          Projection: test.a
            TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2344
/// Two adjacent filters directly above a limit are merged into one
/// conjunctive filter that stays above the limit.
#[test]
fn two_filters_on_same_depth() -> Result<()> {
    let layered = LogicalPlanBuilder::from(test_table_scan()?)
        .limit(0, Some(1))?
        .filter(col("a").lt_eq(lit(1i64)))?
        .filter(col("a").gt_eq(lit(1i64)))?
        .project(vec![col("a")])?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(layered,
    @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Filter: test.a <= Int64(1)
          Limit: skip=0, fetch=1
            TableScan: test
    ",
    );

    assert_optimized_plan_equal!(
        layered,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1) AND test.a <= Int64(1)
        Limit: skip=0, fetch=1
          TableScan: test
    "
    )
}
2376
/// A filter located underneath a `TestUserDefined` node is still pushed into
/// the scan below it.
#[test]
fn filters_user_defined_node() -> Result<()> {
    let filtered_scan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;
    let wrapped = user_defined::new(filtered_scan);

    // Plan as built, before optimization.
    assert_snapshot!(wrapped,
    @r"
    TestUserDefined
      Filter: test.a <= Int64(1)
        TableScan: test
    ",
    );

    assert_optimized_plan_equal!(
        wrapped,
        @r"
    TestUserDefined
      TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2404
/// Inner join on `a`: a filter on the left join key is pushed to the left scan
/// and, rewritten for the right key, to the right scan as well.
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let on_a = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let joined = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, on_a, None)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Inner Join: test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2445
/// Inner `USING (a)` join: a filter on the unqualified join column is pushed
/// to both inputs.
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_a = vec![Column::from_name("a")];
    let joined = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Inner, using_a)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Inner Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2485
/// A filter comparing a left column with a right column cannot go to either
/// input; for an inner join it is folded into the join's filter instead.
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let on_a = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let joined = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, on_a, None)?
        .filter(col("c").lt_eq(col("b")))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test.c <= test2.b
      Inner Join: test.a = test2.a
        Projection: test.a, test.c
          TableScan: test
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Inner Join: test.a = test2.a Filter: test.c <= test2.b
      Projection: test.a, test.c
        TableScan: test
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
2530
/// A filter on a non-key column that exists only on the left input is pushed
/// to the left scan; the right input is left untouched.
#[test]
fn filter_join_on_one_side() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("c")])?
        .build()?;

    let on_a = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let joined = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, on_a, None)?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test.b <= Int64(1)
      Inner Join: test.a = test2.a
        Projection: test.a, test.b
          TableScan: test
        Projection: test2.a, test2.c
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b
        TableScan: test, full_filters=[test.b <= Int64(1)]
      Projection: test2.a, test2.c
        TableScan: test2
    "
    )
}
2575
/// Left `USING (a)` join with a filter on the right side's key: the filter
/// itself stays above the join, while an equivalent predicate on `test.a` is
/// added to the left scan.
#[test]
fn filter_using_left_join() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_a = vec![Column::from_name("a")];
    let joined = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Left, using_a)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test, full_filters=[test.a <= Int64(1)]
        Projection: test2.a
          TableScan: test2
    "
    )
}
2617
/// Right `USING (a)` join with a filter on the left side's key: the filter
/// itself stays above the join, while an equivalent predicate on `test2.a` is
/// added to the right scan.
#[test]
fn filter_using_right_join() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_a = vec![Column::from_name("a")];
    let joined = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Right, using_a)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2658
/// Left `USING (a)` join with a filter on the unqualified key, which resolves
/// to the left side: the filter is pushed to the left scan only and removed
/// from above the join.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_a = vec![Column::from_name("a")];
    let joined = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Left, using_a)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Left Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2
    "
    )
}
2699
/// Right `USING (a)` join with a filter on the right side's key: the filter is
/// pushed to the right scan only and removed from above the join.
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_a = vec![Column::from_name("a")];
    let joined = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Right, using_a)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(joined,
    @r"
    Filter: test2.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );

    assert_optimized_plan_equal!(
        joined,
        @r"
    Right Join: Using test.a = test2.a
      TableScan: test
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2740
#[test]
/// Inner join whose ON filter mixes single-sided and two-sided conjuncts:
/// the single-sided ones move to their scans, the two-sided one stays.
fn join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    // Three conjuncts: left-only, both sides, right-only.
    let left_only = col("test.c").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Only the two-sided comparison remains in the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.c > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2786
#[test]
/// Inner join whose ON filter consists solely of single-sided conjuncts:
/// both are pushed to their scans and the join ends up with no filter.
fn join_filter_removed() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.b").gt(lit(1u32));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // The join filter disappears entirely.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2831
#[test]
/// Inner join on `test.a = test2.b` with an ON filter over the left join key:
/// the predicate is pushed to the left scan and an equivalent predicate on
/// the right join key shows up on the right scan.
fn join_filter_on_common() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;

    let on_filter = col("test.a").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
      Projection: test.a
        TableScan: test
      Projection: test2.b
        TableScan: test2
    ",
    );
    // Both scans carry a predicate; the join filter is gone.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.b
      Projection: test.a
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
2874
#[test]
/// Left join with a three-part ON filter: only the right-only conjunct is
/// pushed (to the right scan); the left-only and two-sided conjuncts stay in
/// the join filter.
fn left_join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Left, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Only `test2.c > 4` leaves the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2920
#[test]
/// Right join with a three-part ON filter: only the left-only conjunct is
/// pushed (to the left scan); the two-sided and right-only conjuncts stay in
/// the join filter.
fn right_join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Right, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Only `test.a > 1` leaves the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
2966
#[test]
/// Full join with a three-part ON filter: nothing is pushed down, the
/// optimized plan is identical to the input plan.
fn full_join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Full, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // The plan comes back unchanged.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3012
3013 struct PushDownProvider {
3014 pub filter_support: TableProviderFilterPushDown,
3015 }
3016
3017 #[async_trait]
3018 impl TableSource for PushDownProvider {
3019 fn schema(&self) -> SchemaRef {
3020 Arc::new(Schema::new(vec![
3021 Field::new("a", DataType::Int32, true),
3022 Field::new("b", DataType::Int32, true),
3023 ]))
3024 }
3025
3026 fn table_type(&self) -> TableType {
3027 TableType::Base
3028 }
3029
3030 fn supports_filters_pushdown(
3031 &self,
3032 filters: &[&Expr],
3033 ) -> Result<Vec<TableProviderFilterPushDown>> {
3034 Ok((0..filters.len())
3035 .map(|_| self.filter_support.clone())
3036 .collect())
3037 }
3038
3039 fn as_any(&self) -> &dyn Any {
3040 self
3041 }
3042 }
3043
3044 fn table_scan_with_pushdown_provider_builder(
3045 filter_support: TableProviderFilterPushDown,
3046 filters: Vec<Expr>,
3047 projection: Option<Vec<usize>>,
3048 ) -> Result<LogicalPlanBuilder> {
3049 let test_provider = PushDownProvider { filter_support };
3050
3051 let table_scan = LogicalPlan::TableScan(TableScan {
3052 table_name: "test".into(),
3053 filters,
3054 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3055 projection,
3056 source: Arc::new(test_provider),
3057 fetch: None,
3058 });
3059
3060 Ok(LogicalPlanBuilder::from(table_scan))
3061 }
3062
3063 fn table_scan_with_pushdown_provider(
3064 filter_support: TableProviderFilterPushDown,
3065 ) -> Result<LogicalPlan> {
3066 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3067 .filter(col("a").eq(lit(1i64)))?
3068 .build()
3069 }
3070
#[test]
/// With `Exact` support the filter node disappears and the predicate is
/// listed in the scan's `full_filters`.
fn filter_with_table_provider_exact() -> Result<()> {
    let support = TableProviderFilterPushDown::Exact;
    let filtered_scan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        filtered_scan,
        @"TableScan: test, full_filters=[a = Int64(1)]"
    )
}
3080
#[test]
/// With `Inexact` support the predicate is listed in the scan's
/// `partial_filters` and the filter node is kept above the scan.
fn filter_with_table_provider_inexact() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let filtered_scan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        filtered_scan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3094
#[test]
/// Applies the rule by hand once and then again through the assertion macro:
/// the `Inexact` predicate still appears exactly once in `partial_filters`.
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let filtered_scan = table_scan_with_pushdown_provider(support)?;

    // First, explicit application of the rule.
    let rule = PushDownFilter::new();
    let config = OptimizerContext::new();
    let first_pass = rule
        .rewrite(filtered_scan, &config)
        .expect("failed to optimize plan")
        .data;

    // The macro optimizes again; the predicate must not be duplicated.
    assert_optimized_plan_equal!(
        first_pass,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3115
#[test]
/// With `Unsupported` the scan receives no filters and the filter node stays
/// where it was.
fn filter_with_table_provider_unsupported() -> Result<()> {
    let support = TableProviderFilterPushDown::Unsupported;
    let filtered_scan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        filtered_scan,
        @r"
    Filter: a = Int64(1)
      TableScan: test
    "
    )
}
3129
#[test]
/// The scan already carries the same two predicates that the filter above it
/// applies (`Inexact` support): the scan's filter list is not duplicated and
/// the filter node is retained.
fn multi_combined_filter() -> Result<()> {
    let a_is_ten = col("a").eq(lit(10i64));
    let b_above_eleven = col("b").gt(lit(11i64));

    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Inexact,
        vec![a_is_ten.clone(), b_above_eleven.clone()],
        Some(vec![0]),
    )?
    .filter(and(a_is_ten, b_above_eleven))?
    .project(vec![col("a"), col("b")])?
    .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: a = Int64(10) AND b > Int64(11)
        TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3150
#[test]
/// With `Exact` support a two-part conjunction is split into two entries of
/// the scan's `full_filters` and the filter node is removed.
fn multi_combined_filter_exact() -> Result<()> {
    let a_is_ten = col("a").eq(lit(10i64));
    let b_above_eleven = col("b").gt(lit(11i64));

    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Exact,
        Vec::new(),
        Some(vec![0]),
    )?
    .filter(and(a_is_ten, b_above_eleven))?
    .project(vec![col("a"), col("b")])?
    .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3170
#[test]
/// A filter that refers to a projection alias (`b` for `test.a`) is rewritten
/// in terms of the underlying column when it is pushed below the projection.
fn test_filter_with_alias() -> Result<()> {
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // `b` has been replaced by `test.a` inside the scan filters.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3199
#[test]
/// Same as `test_filter_with_alias`, but with a second projection stacked on
/// top of the aliasing one: the filter passes through both projections.
fn test_filter_with_alias_2() -> Result<()> {
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // `b` has been replaced by `test.a` inside the scan filters.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3231
#[test]
/// Both columns used by the filter are projection aliases (`b` and `d`);
/// each is rewritten to its source column when the filter reaches the scan.
fn test_filter_with_multi_alias() -> Result<()> {
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c").alias("d")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND d > Int64(10)
      Projection: test.a AS b, test.c AS d
        TableScan: test
    ",
    );
    // The scan filters reference `test.a` and `test.c`, not the aliases.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c AS d
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3257
#[test]
/// Inner join over aliased keys (`c = d`) with an ON filter on `c`: the
/// predicate reaches both scans, expressed over the un-aliased columns.
fn join_filter_with_alias() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;

    let on_filter = col("c").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Inner Join: c = d Filter: c > UInt32(1)
      Projection: test.a AS c
        TableScan: test
      Projection: test2.b AS d
        TableScan: test2
    ",
    );
    // Both scans are filtered on the column behind the alias.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: c = d
      Projection: test.a AS c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b AS d
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
3300
#[test]
/// An `IN (list)` filter on a projection alias is pushed to the scan with the
/// alias replaced by the underlying column.
fn test_in_filter_with_alias() -> Result<()> {
    let scan = test_table_scan()?;
    // UInt32 literals 1 through 4.
    let candidates: Vec<_> = (1u32..=4).map(lit).collect();
    let predicate = in_list(col("b"), candidates, false);
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // The list test is now expressed over `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3330
#[test]
/// Same as `test_in_filter_with_alias`, but with a second projection stacked
/// on top of the aliasing one: the `IN (list)` filter passes through both.
fn test_in_filter_with_alias_2() -> Result<()> {
    let scan = test_table_scan()?;
    // UInt32 literals 1 through 4.
    let candidates: Vec<_> = (1u32..=4).map(lit).collect();
    let predicate = in_list(col("b"), candidates, false);
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // The list test is now expressed over `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3363
#[test]
/// An `IN (<subquery>)` filter on a projection alias is pushed to the scan
/// with the alias replaced by the underlying column; the subquery plan moves
/// along with the predicate.
fn test_in_subquery_with_alias() -> Result<()> {
    let outer_scan = test_table_scan()?;
    let subquery_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .project(vec![col("c")])?
        .build()?;
    let predicate = in_subquery(col("b"), Arc::new(subquery_plan));
    let plan = LogicalPlanBuilder::from(outer_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: b IN (<subquery>)
      Subquery:
        Projection: sq.c
          TableScan: sq
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // The subquery is now rendered underneath the scan that owns the filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN (<subquery>)]
        Subquery:
          Projection: sq.c
            TableScan: sq
    "
    )
}
3403
#[test]
/// A filter above two nested `SubqueryAlias`/`Projection` layers is carried
/// all the way down to just above the `EmptyRelation`, with `b.a` replaced by
/// the literal it was projected from.
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    // Innermost layer: a single row exposing the literal 0 as `b.a`.
    let inner_layer = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?;
    // Second layer re-projects the column under the same alias.
    let outer_layer = inner_layer.project(vec![col("b.a")])?.alias("b")?;
    let plan = outer_layer
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Projection: b.a
      Filter: b.a = Int64(1)
        SubqueryAlias: b
          Projection: b.a
            SubqueryAlias: b
              Projection: Int64(0) AS a
                EmptyRelation: rows=1
    ",
    );
    // The filter sits directly on the empty relation as `0 = 1`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b.a
      SubqueryAlias: b
        Projection: b.a
          SubqueryAlias: b
            Projection: Int64(0) AS a
              Filter: Int64(0) = Int64(1)
                EmptyRelation: rows=1
    "
    )
}
3442
#[test]
/// Cross join followed by an OR of two conjunctions. The join becomes an
/// inner join carrying the whole predicate, and the part that only involves
/// `test` (`test.b > 1 OR test.c < 10`) is additionally placed on the left
/// scan. Applying the rule a second time yields the same plan.
fn test_crossjoin_with_or_clause() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;

    let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
    let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
    let predicate = or(first_branch, second_branch);

    let plan = LogicalPlanBuilder::from(left_input)
        .cross_join(right_input)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
    Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    ")?;

    // Apply the rule by hand, then once more through the macro below.
    let rule = PushDownFilter::new();
    let config = OptimizerContext::new();
    let first_pass = rule
        .rewrite(plan, &config)
        .expect("failed to optimize plan")
        .data;
    assert_optimized_plan_equal!(
        first_pass,
        @r"
    Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    "
    )
}
3488
#[test]
/// Filter on `test2.a` above a left semi join keyed on `test1.a = test2.a`:
/// the filter node stays in place and a matching predicate on `test1.a`
/// appears on the left scan.
fn left_semi_join() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftSemi, join_keys, None)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // The left scan gains `test1.a <= 1`; the right side is unchanged.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1, full_filters=[test1.a <= Int64(1)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3531
#[test]
/// Left semi join whose ON filter has one left-only and one right-only
/// conjunct: both are pushed to their scans and the join filter is removed.
fn left_semi_join_with_filters() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftSemi, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Each conjunct ends up on the scan of the table it references.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3575
#[test]
/// Filter on `test1.a` above a right semi join keyed on `test1.a = test2.a`:
/// the filter node stays in place and a matching predicate on `test2.a`
/// appears on the right scan.
fn right_semi_join() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightSemi, join_keys, None)?
        .filter(col("test1.a").lt_eq(lit(1i64)))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // The right scan gains `test2.a <= 1`; the left side is unchanged.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
3618
#[test]
/// Right semi join whose ON filter has one left-only and one right-only
/// conjunct: both are pushed to their scans and the join filter is removed.
fn right_semi_join_with_filters() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightSemi, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Each conjunct ends up on the scan of the table it references.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3662
#[test]
/// Filter on `test2.a` above a left anti join keyed on `test1.a = test2.a`:
/// the filter node stays in place and a matching predicate on `test1.a`
/// appears on the left scan.
fn left_anti_join() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftAnti, join_keys, None)?
        .filter(col("test2.a").gt(lit(2u32)))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // The left scan gains `test1.a > 2`; the right side is unchanged.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1, full_filters=[test1.a > UInt32(2)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3710
#[test]
/// Left anti join whose ON filter has one left-only and one right-only
/// conjunct: only the right-only conjunct is pushed (to the right scan); the
/// left-only conjunct stays in the join filter.
fn left_anti_join_with_filters() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftAnti, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // `test1.b > 1` remains on the join; `test2.b > 2` moves to the right scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3759
#[test]
/// Filter on `test1.a` above a right anti join keyed on `test1.a = test2.a`:
/// the filter node stays in place and a matching predicate on `test2.a`
/// appears on the right scan.
fn right_anti_join() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightAnti, join_keys, None)?
        .filter(col("test1.a").gt(lit(2u32)))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // The right scan gains `test2.a > 2`; the left side is unchanged.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a > UInt32(2)]
    "
    )
}
3807
#[test]
/// Right anti join whose ON filter has one left-only and one right-only
/// conjunct: only the left-only conjunct is pushed (to the left scan); the
/// right-only conjunct stays in the join filter.
fn right_anti_join_with_filters() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightAnti, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // `test2.b > 2` remains on the join; `test1.b > 1` moves to the left scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
3856
3857 #[derive(Debug, PartialEq, Eq, Hash)]
3858 struct TestScalarUDF {
3859 signature: Signature,
3860 }
3861
3862 impl ScalarUDFImpl for TestScalarUDF {
3863 fn as_any(&self) -> &dyn Any {
3864 self
3865 }
3866 fn name(&self) -> &str {
3867 "TestScalarUDF"
3868 }
3869
3870 fn signature(&self) -> &Signature {
3871 &self.signature
3872 }
3873
3874 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3875 Ok(DataType::Int32)
3876 }
3877
3878 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3879 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3880 }
3881 }
3882
#[test]
/// A filter with one conjunct on a grouping column and one on a column
/// computed from a volatile UDF: the grouping-column conjunct reaches the
/// scan, while the conjunct on `r` stops above the projection that computes
/// the volatile expression.
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
    let scan = test_table_scan_with_name("test1")?;

    // Zero-argument UDF declared as volatile.
    let signature = Signature::exact(vec![], Volatility::Volatile);
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF { signature }));
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    let predicate = col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5)));
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("a")], vec![sum(col("b"))])?
        .project(vec![
            col("a"),
            sum(col("b")),
            add(volatile_call, lit(1)).alias("r"),
        ])?
        .alias("t")?
        .filter(predicate)?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.a > Int32(5) AND t.r > Float64(0.5)
        SubqueryAlias: t
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1
    ",
    );
    // `a > 5` reaches the scan; `r > 0.5` stays above the projection.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.5)
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1, full_filters=[test1.a > Int32(5)]
    "
    )
}
3922
#[test]
/// A filter on a column computed from a volatile UDF, sitting above an alias,
/// a projection and an inner join: the filter moves below the alias but stops
/// above the projection that computes the volatile expression.
fn test_push_down_volatile_function_in_join() -> Result<()> {
    let left_scan = test_table_scan_with_name("test1")?;

    // Zero-argument UDF declared as volatile.
    let signature = Signature::exact(vec![], Volatility::Volatile);
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF { signature }));
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    let left_input = LogicalPlanBuilder::from(left_scan).build()?;
    let right_input =
        LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;
    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, None)?
        .project(vec![col("test1.a").alias("a"), volatile_call.alias("r")])?
        .alias("t")?
        .filter(col("t.r").gt(lit(0.8)))?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Shape of the plan before the rule runs (recorded for reference).
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.r > Float64(0.8)
        SubqueryAlias: t
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    ",
    );
    // The filter is now below the alias but still above the projection.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.8)
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    "
    )
}
3974
/// Filter whose only predicate compares a volatile UDF call to a literal.
///
/// Per the expected snapshot, the filter moves below the projection but
/// remains a `Filter` node above the scan; the scan gains no
/// `full_filters`.
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));
    let predicate = volatile_call.gt(lit(0.1));

    let scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1)
      Projection: test.a, test.b
        TableScan: test
    ",
    );

    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4004
/// Filter mixing one volatile conjunct with two plain column comparisons.
///
/// Per the expected snapshot, the two column comparisons land in the scan's
/// `full_filters`, while the volatile comparison is kept as a `Filter`
/// directly above the scan.
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    // Conjuncts are combined left-to-right in this order.
    let volatile_cmp = volatile_call.gt(lit(0.1));
    let a_cmp = col("t.a").gt(lit(5));
    let b_cmp = col("t.b").gt(lit(10));

    let scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .filter(volatile_cmp.and(a_cmp).and(b_cmp))?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: test.a, test.b
        TableScan: test
    ",
    );

    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
    "
    )
}
4038
/// Same mixed filter as the supported-provider case, but over a scan built
/// with `TableProviderFilterPushDown::Unsupported`.
///
/// Per the expected snapshot, nothing reaches the scan: all three conjuncts
/// remain in a single `Filter` below the projection, with the two column
/// comparisons listed before the volatile one.
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    // Conjuncts are combined left-to-right in this order.
    let volatile_cmp = volatile_call.gt(lit(0.1));
    let a_cmp = col("t.a").gt(lit(5));
    let b_cmp = col("t.b").gt(lit(10));

    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?;
    let plan = scan_builder
        .project(vec![col("a"), col("b")])?
        .filter(volatile_cmp.and(a_cmp).and(b_cmp))?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: a, b
        TableScan: test
    ",
    );

    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4075
/// Runs the rule on a `Filter` placed directly over an input-less
/// user-defined extension node.
///
/// Per the expected snapshot, the optimized plan is identical to the plan
/// as built: the filter stays above `TestUserNode`.
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    /// Leaf extension node with a fixed single-column schema.
    #[derive(Debug, Hash, Eq, PartialEq)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    impl PartialOrd for TestUserNode {
        /// Instances are never ordered relative to each other.
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl TestUserNode {
        /// Builds the node with one unqualified, non-nullable `Int64`
        /// column named `a`.
        fn new() -> Self {
            let fields = vec![(None, Field::new("a", DataType::Int64, false).into())];
            let df_schema =
                DFSchema::new_with_metadata(fields, Default::default()).unwrap();

            Self {
                schema: Arc::new(df_schema),
            }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        /// The node is a leaf: it has no child plans.
        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        /// The node carries no expressions.
        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            f.write_str("TestUserNode")
        }

        /// Rebuilds the node; it accepts neither expressions nor inputs
        /// and shares the existing schema.
        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            let schema = Arc::clone(&self.schema);
            Ok(Self { schema })
        }
    }

    // Wrap the custom node as a logical-plan extension and filter it.
    let extension = Extension {
        node: Arc::new(TestUserNode::new()),
    };
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(extension))
        .filter(lit(false))?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: Boolean(false)
      TestUserNode
    ",
    );

    // Plan after the rule runs.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: Boolean(false)
      TestUserNode
    "
    )
}
4161}