1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31 Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32 internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41 BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48use datafusion_expr::ExpressionPlacement;
49
/// Optimizer rule that pushes filter predicates down the logical plan, closer
/// to the data sources.
///
/// `Filter` predicates are moved below other plan nodes where that preserves
/// the query's result, pushed into the filters of joins and table scans, and a
/// join's own `ON` filter is split between the join's inputs where allowed.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
139
140pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
165 match join_type {
166 JoinType::Inner => (true, true),
167 JoinType::Left => (true, false),
168 JoinType::Right => (false, true),
169 JoinType::Full => (false, false),
170 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
173 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
176 }
177}
178
179pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
189 match join_type {
190 JoinType::Inner => (true, true),
191 JoinType::Left => (false, true),
192 JoinType::Right => (true, false),
193 JoinType::Full => (false, false),
194 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
195 JoinType::LeftAnti => (false, true),
196 JoinType::RightAnti => (true, false),
197 JoinType::LeftMark => (false, true),
198 JoinType::RightMark => (true, false),
199 }
200}
201
202#[derive(Debug)]
205struct ColumnChecker<'a> {
206 left_schema: &'a DFSchema,
208 left_columns: Option<HashSet<Column>>,
210 right_schema: &'a DFSchema,
212 right_columns: Option<HashSet<Column>>,
214}
215
216impl<'a> ColumnChecker<'a> {
217 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
218 Self {
219 left_schema,
220 left_columns: None,
221 right_schema,
222 right_columns: None,
223 }
224 }
225
226 fn is_left_only(&mut self, predicate: &Expr) -> bool {
228 if self.left_columns.is_none() {
229 self.left_columns = Some(schema_columns(self.left_schema));
230 }
231 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
232 }
233
234 fn is_right_only(&mut self, predicate: &Expr) -> bool {
236 if self.right_columns.is_none() {
237 self.right_columns = Some(schema_columns(self.right_schema));
238 }
239 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
240 }
241}
242
243fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
245 schema
246 .iter()
247 .flat_map(|(qualifier, field)| {
248 [
249 Column::new(qualifier.cloned(), field.name()),
250 Column::new_unqualified(field.name()),
252 ]
253 })
254 .collect::<HashSet<_>>()
255}
256
/// Returns `true` if `predicate` can be moved into a join's filter.
///
/// The predicate tree is walked; it is rejected as soon as a subquery
/// (`EXISTS`, `IN (subquery)`, set comparison, scalar subquery), an outer
/// reference column, or an `UNNEST` is found.
///
/// # Errors
///
/// Returns an internal error ("Unsupported predicate type") if an aggregate
/// function, window function, wildcard, or grouping set is encountered.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
        // Leaf-like nodes: accept and skip this node's children (`Jump`).
        Expr::Column(_)
        | Expr::Literal(_, _)
        | Expr::Placeholder(_)
        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
        // Cannot be evaluated as a join condition: record and stop the walk.
        Expr::Exists { .. }
        | Expr::InSubquery(_)
        | Expr::SetComparison(_)
        | Expr::ScalarSubquery(_)
        | Expr::OuterReferenceColumn(_, _)
        | Expr::Unnest(_) => {
            is_evaluate = false;
            Ok(TreeNodeRecursion::Stop)
        }
        // Row-level expressions: fine by themselves, keep inspecting children.
        Expr::Alias(_)
        | Expr::BinaryExpr(_)
        | Expr::Like(_)
        | Expr::SimilarTo(_)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
        | Expr::IsNull(_)
        | Expr::IsTrue(_)
        | Expr::IsFalse(_)
        | Expr::IsUnknown(_)
        | Expr::IsNotTrue(_)
        | Expr::IsNotFalse(_)
        | Expr::IsNotUnknown(_)
        | Expr::Negative(_)
        | Expr::Between(_)
        | Expr::Case(_)
        | Expr::Cast(_)
        | Expr::TryCast(_)
        | Expr::InList { .. }
        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
        // Silences a deprecation lint raised by one of the variants matched in
        // this arm (presumably `Expr::Wildcard` — confirm before removing).
        #[expect(deprecated)]
        Expr::AggregateFunction(_)
        | Expr::WindowFunction(_)
        | Expr::Wildcard { .. }
        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
    })?;
    Ok(is_evaluate)
}
303
304fn extract_or_clauses_for_join<'a>(
338 filters: &'a [Expr],
339 schema: &'a DFSchema,
340) -> impl Iterator<Item = Expr> + 'a {
341 let schema_columns = schema_columns(schema);
342
343 filters.iter().filter_map(move |expr| {
345 if let Expr::BinaryExpr(BinaryExpr {
346 left,
347 op: Operator::Or,
348 right,
349 }) = expr
350 {
351 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
352 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
353
354 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
356 return Some(or(left_expr, right_expr));
357 }
358 }
359 None
360 })
361}
362
363fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
375 let mut predicate = None;
376
377 match expr {
378 Expr::BinaryExpr(BinaryExpr {
379 left: l_expr,
380 op: Operator::Or,
381 right: r_expr,
382 }) => {
383 let l_expr = extract_or_clause(l_expr, schema_columns);
384 let r_expr = extract_or_clause(r_expr, schema_columns);
385
386 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
387 predicate = Some(or(l_expr, r_expr));
388 }
389 }
390 Expr::BinaryExpr(BinaryExpr {
391 left: l_expr,
392 op: Operator::And,
393 right: r_expr,
394 }) => {
395 let l_expr = extract_or_clause(l_expr, schema_columns);
396 let r_expr = extract_or_clause(r_expr, schema_columns);
397
398 match (l_expr, r_expr) {
399 (Some(l_expr), Some(r_expr)) => {
400 predicate = Some(and(l_expr, r_expr));
401 }
402 (Some(l_expr), None) => {
403 predicate = Some(l_expr);
404 }
405 (None, Some(r_expr)) => {
406 predicate = Some(r_expr);
407 }
408 (None, None) => {
409 predicate = None;
410 }
411 }
412 }
413 _ => {
414 if has_all_column_refs(expr, schema_columns) {
415 predicate = Some(expr.clone());
416 }
417 }
418 }
419
420 predicate
421}
422
/// Distributes filter predicates around and into `join`.
///
/// `predicates` come from a filter above the join, `inferred_join_predicates`
/// were derived through the join keys, and `on_filter` holds the conjuncts of
/// the join's own filter. Each predicate from above ends up in one of:
/// 1. a new `Filter` on the left or right input, when it references only that
///    side and `lr_is_preserved` allows pushing to it;
/// 2. the join's filter, when the join is `Inner` and
///    `can_evaluate_as_join_condition` accepts it;
/// 3. a `Filter` kept above the join.
///
/// `ON` conjuncts are pushed to a side only when `on_lr_is_preserved` allows
/// it; the rest stay in the join's filter. Inferred predicates are pushed to
/// whichever side they reference exclusively and are otherwise discarded.
/// Additionally, for predicates that remain above or inside the join, a
/// side-only `OR` predicate is derived with `extract_or_clauses_for_join` and
/// pushed to each permitted side.
///
/// Returns `Transformed::yes` on success.
fn push_down_all_join(
    predicates: Vec<Expr>,
    inferred_join_predicates: Vec<Expr>,
    mut join: Join,
    on_filter: Vec<Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let is_inner_join = join.join_type == JoinType::Inner;
    // Which inputs may receive predicates coming from above the join.
    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

    let left_schema = join.left.schema();
    let right_schema = join.right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    let mut checker = ColumnChecker::new(left_schema, right_schema);
    for predicate in predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Equality and non-equality predicates are not distinguished here;
            // both are folded into the join filter.
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // Inferred predicates are only useful when they can be pushed to one side.
    for predicate in inferred_join_predicates {
        if checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if checker.is_right_only(&predicate) {
            right_push.push(predicate);
        }
    }

    let mut on_filter_join_conditions = vec![];
    // Which inputs may receive conjuncts of the join's own ON filter.
    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);

    if !on_filter.is_empty() {
        for on in on_filter {
            if on_left_preserved && checker.is_left_only(&on) {
                left_push.push(on)
            } else if on_right_preserved && checker.is_right_only(&on) {
                right_push.push(on)
            } else {
                on_filter_join_conditions.push(on)
            }
        }
    }

    // From the OR predicates that could not be pushed as a whole, derive
    // side-only predicates for each side that is preserved w.r.t. filters above.
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }

    // Same for the remaining ON conjuncts, using ON-filter preservation rules.
    if on_left_preserved {
        left_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            left_schema,
        ));
    }
    if on_right_preserved {
        right_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            right_schema,
        ));
    }

    if let Some(predicate) = conjunction(left_push) {
        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
    }
    if let Some(predicate) = conjunction(right_push) {
        join.right =
            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
    }

    // The join filter becomes: predicates moved in from above, followed by the
    // ON conjuncts that could not be pushed.
    join_conditions.extend(on_filter_join_conditions);
    join.filter = conjunction(join_conditions);

    let plan = LogicalPlan::Join(join);
    // Re-wrap the join with whatever must still be filtered above it.
    let plan = if let Some(predicate) = conjunction(keep_predicates) {
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
    } else {
        plan
    };
    Ok(Transformed::yes(plan))
}
530
531fn push_down_join(
532 join: Join,
533 parent_predicate: Option<&Expr>,
534) -> Result<Transformed<LogicalPlan>> {
535 let predicates = parent_predicate
537 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
538
539 let on_filters = join
541 .filter
542 .as_ref()
543 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
544
545 let inferred_join_predicates =
547 infer_join_predicates(&join, &predicates, &on_filters)?;
548
549 if on_filters.is_empty()
550 && predicates.is_empty()
551 && inferred_join_predicates.is_empty()
552 {
553 return Ok(Transformed::no(LogicalPlan::Join(join)));
554 }
555
556 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
557}
558
559fn infer_join_predicates(
569 join: &Join,
570 predicates: &[Expr],
571 on_filters: &[Expr],
572) -> Result<Vec<Expr>> {
573 let join_col_keys = join
575 .on
576 .iter()
577 .filter_map(|(l, r)| {
578 let left_col = l.try_as_col()?;
579 let right_col = r.try_as_col()?;
580 Some((left_col, right_col))
581 })
582 .collect::<Vec<_>>();
583
584 let join_type = join.join_type;
585
586 let mut inferred_predicates = InferredPredicates::new(join_type);
587
588 infer_join_predicates_from_predicates(
589 &join_col_keys,
590 predicates,
591 &mut inferred_predicates,
592 )?;
593
594 infer_join_predicates_from_on_filters(
595 &join_col_keys,
596 join_type,
597 on_filters,
598 &mut inferred_predicates,
599 )?;
600
601 Ok(inferred_predicates.predicates)
602}
603
604struct InferredPredicates {
613 predicates: Vec<Expr>,
614 is_inner_join: bool,
615}
616
617impl InferredPredicates {
618 fn new(join_type: JoinType) -> Self {
619 Self {
620 predicates: vec![],
621 is_inner_join: join_type == JoinType::Inner,
622 }
623 }
624
625 fn try_build_predicate(
626 &mut self,
627 predicate: Expr,
628 replace_map: &HashMap<&Column, &Column>,
629 ) -> Result<()> {
630 if self.is_inner_join
631 || matches!(
632 is_restrict_null_predicate(
633 predicate.clone(),
634 replace_map.keys().cloned()
635 ),
636 Ok(true)
637 )
638 {
639 self.predicates.push(replace_col(predicate, replace_map)?);
640 }
641
642 Ok(())
643 }
644}
645
646fn infer_join_predicates_from_predicates(
655 join_col_keys: &[(&Column, &Column)],
656 predicates: &[Expr],
657 inferred_predicates: &mut InferredPredicates,
658) -> Result<()> {
659 infer_join_predicates_impl::<true, true>(
660 join_col_keys,
661 predicates,
662 inferred_predicates,
663 )
664}
665
666fn infer_join_predicates_from_on_filters(
678 join_col_keys: &[(&Column, &Column)],
679 join_type: JoinType,
680 on_filters: &[Expr],
681 inferred_predicates: &mut InferredPredicates,
682) -> Result<()> {
683 match join_type {
684 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
685 JoinType::Inner => infer_join_predicates_impl::<true, true>(
686 join_col_keys,
687 on_filters,
688 inferred_predicates,
689 ),
690 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
691 infer_join_predicates_impl::<true, false>(
692 join_col_keys,
693 on_filters,
694 inferred_predicates,
695 )
696 }
697 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
698 infer_join_predicates_impl::<false, true>(
699 join_col_keys,
700 on_filters,
701 inferred_predicates,
702 )
703 }
704 }
705}
706
707fn infer_join_predicates_impl<
723 const ENABLE_LEFT_TO_RIGHT: bool,
724 const ENABLE_RIGHT_TO_LEFT: bool,
725>(
726 join_col_keys: &[(&Column, &Column)],
727 input_predicates: &[Expr],
728 inferred_predicates: &mut InferredPredicates,
729) -> Result<()> {
730 for predicate in input_predicates {
731 let mut join_cols_to_replace = HashMap::new();
732
733 for &col in &predicate.column_refs() {
734 for (l, r) in join_col_keys.iter() {
735 if ENABLE_LEFT_TO_RIGHT && col == *l {
736 join_cols_to_replace.insert(col, *r);
737 break;
738 }
739 if ENABLE_RIGHT_TO_LEFT && col == *r {
740 join_cols_to_replace.insert(col, *l);
741 break;
742 }
743 }
744 }
745 if join_cols_to_replace.is_empty() {
746 continue;
747 }
748
749 inferred_predicates
750 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
751 }
752 Ok(())
753}
754
755impl OptimizerRule for PushDownFilter {
756 fn name(&self) -> &str {
757 "push_down_filter"
758 }
759
760 fn apply_order(&self) -> Option<ApplyOrder> {
761 Some(ApplyOrder::TopDown)
762 }
763
764 fn supports_rewrite(&self) -> bool {
765 true
766 }
767
768 fn rewrite(
769 &self,
770 plan: LogicalPlan,
771 config: &dyn OptimizerConfig,
772 ) -> Result<Transformed<LogicalPlan>> {
773 let _ = config.options();
774 if let LogicalPlan::Join(join) = plan {
775 return push_down_join(join, None);
776 };
777
778 let plan_schema = Arc::clone(plan.schema());
779
780 let LogicalPlan::Filter(mut filter) = plan else {
781 return Ok(Transformed::no(plan));
782 };
783
784 let predicate = split_conjunction_owned(filter.predicate.clone());
785 let old_predicate_len = predicate.len();
786 let new_predicates = simplify_predicates(predicate)?;
787 if old_predicate_len != new_predicates.len() {
788 let Some(new_predicate) = conjunction(new_predicates) else {
789 return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
792 };
793 filter.predicate = new_predicate;
794 }
795
796 if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() {
800 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
801 }
802
803 match Arc::unwrap_or_clone(filter.input) {
804 LogicalPlan::Filter(child_filter) => {
805 let parents_predicates = split_conjunction_owned(filter.predicate);
806
807 let child_predicates = split_conjunction_owned(child_filter.predicate);
809 let new_predicates = parents_predicates
810 .into_iter()
811 .chain(child_predicates)
812 .collect::<IndexSet<_>>()
814 .into_iter()
815 .collect::<Vec<_>>();
816
817 let Some(new_predicate) = conjunction(new_predicates) else {
818 return plan_err!("at least one expression exists");
819 };
820 let new_filter = LogicalPlan::Filter(Filter::try_new(
821 new_predicate,
822 child_filter.input,
823 )?);
824 self.rewrite(new_filter, config)
825 }
826 LogicalPlan::Repartition(repartition) => {
827 let new_filter =
828 Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
829 .map(LogicalPlan::Filter)?;
830 insert_below(LogicalPlan::Repartition(repartition), new_filter)
831 }
832 LogicalPlan::Distinct(distinct) => {
833 let new_filter =
834 Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
835 .map(LogicalPlan::Filter)?;
836 insert_below(LogicalPlan::Distinct(distinct), new_filter)
837 }
838 LogicalPlan::Sort(sort) => {
839 let new_filter =
840 Filter::try_new(filter.predicate, Arc::clone(&sort.input))
841 .map(LogicalPlan::Filter)?;
842 insert_below(LogicalPlan::Sort(sort), new_filter)
843 }
844 LogicalPlan::SubqueryAlias(subquery_alias) => {
845 let mut replace_map = HashMap::new();
846 for (i, (qualifier, field)) in
847 subquery_alias.input.schema().iter().enumerate()
848 {
849 let (sub_qualifier, sub_field) =
850 subquery_alias.schema.qualified_field(i);
851 replace_map.insert(
852 qualified_name(sub_qualifier, sub_field.name()),
853 Expr::Column(Column::new(qualifier.cloned(), field.name())),
854 );
855 }
856 let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
857
858 let new_filter = LogicalPlan::Filter(Filter::try_new(
859 new_predicate,
860 Arc::clone(&subquery_alias.input),
861 )?);
862 insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
863 }
864 LogicalPlan::Projection(projection) => {
865 let predicates = split_conjunction_owned(filter.predicate.clone());
866 let (new_projection, keep_predicate) =
867 rewrite_projection(predicates, projection)?;
868 if new_projection.transformed {
869 match keep_predicate {
870 None => Ok(new_projection),
871 Some(keep_predicate) => new_projection.map_data(|child_plan| {
872 Filter::try_new(keep_predicate, Arc::new(child_plan))
873 .map(LogicalPlan::Filter)
874 }),
875 }
876 } else {
877 filter.input = Arc::new(new_projection.data);
878 Ok(Transformed::no(LogicalPlan::Filter(filter)))
879 }
880 }
881 LogicalPlan::Unnest(mut unnest) => {
882 let predicates = split_conjunction_owned(filter.predicate.clone());
883 let mut non_unnest_predicates = vec![];
884 let mut unnest_predicates = vec![];
885 let mut unnest_struct_columns = vec![];
886
887 for idx in &unnest.struct_type_columns {
888 let (sub_qualifier, field) =
889 unnest.input.schema().qualified_field(*idx);
890 let field_name = field.name().clone();
891
892 if let DataType::Struct(children) = field.data_type() {
893 for child in children {
894 let child_name = child.name().clone();
895 unnest_struct_columns.push(Column::new(
896 sub_qualifier.cloned(),
897 format!("{field_name}.{child_name}"),
898 ));
899 }
900 }
901 }
902
903 for predicate in predicates {
904 let mut accum: HashSet<Column> = HashSet::new();
906 expr_to_columns(&predicate, &mut accum)?;
907
908 let contains_list_columns =
909 unnest.list_type_columns.iter().any(|(_, unnest_list)| {
910 accum.contains(&unnest_list.output_column)
911 });
912 let contains_struct_columns =
913 unnest_struct_columns.iter().any(|c| accum.contains(c));
914
915 if contains_list_columns || contains_struct_columns {
916 unnest_predicates.push(predicate);
917 } else {
918 non_unnest_predicates.push(predicate);
919 }
920 }
921
922 if non_unnest_predicates.is_empty() {
925 filter.input = Arc::new(LogicalPlan::Unnest(unnest));
926 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
927 }
928
929 let unnest_input = std::mem::take(&mut unnest.input);
938
939 let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
940 conjunction(non_unnest_predicates).unwrap(), unnest_input,
942 )?);
943
944 let unnest_plan =
948 insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
949
950 match conjunction(unnest_predicates) {
951 None => Ok(unnest_plan),
952 Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
953 Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
954 ))),
955 }
956 }
957 LogicalPlan::Union(ref union) => {
958 let mut inputs = Vec::with_capacity(union.inputs.len());
959 for input in &union.inputs {
960 let mut replace_map = HashMap::new();
961 for (i, (qualifier, field)) in input.schema().iter().enumerate() {
962 let (union_qualifier, union_field) =
963 union.schema.qualified_field(i);
964 replace_map.insert(
965 qualified_name(union_qualifier, union_field.name()),
966 Expr::Column(Column::new(qualifier.cloned(), field.name())),
967 );
968 }
969
970 let push_predicate =
971 replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
972 inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
973 push_predicate,
974 Arc::clone(input),
975 )?)))
976 }
977 Ok(Transformed::yes(LogicalPlan::Union(Union {
978 inputs,
979 schema: Arc::clone(&plan_schema),
980 })))
981 }
982 LogicalPlan::Aggregate(agg) => {
983 let group_expr_columns = agg
985 .group_expr
986 .iter()
987 .map(|e| {
988 let (relation, name) = e.qualified_name();
989 Column::new(relation, name)
990 })
991 .collect::<HashSet<_>>();
992
993 let predicates = split_conjunction_owned(filter.predicate);
994
995 let mut keep_predicates = vec![];
996 let mut push_predicates = vec![];
997 for expr in predicates {
998 let cols = expr.column_refs();
999 if cols.iter().all(|c| group_expr_columns.contains(c)) {
1000 push_predicates.push(expr);
1001 } else {
1002 keep_predicates.push(expr);
1003 }
1004 }
1005
1006 let mut replace_map = HashMap::new();
1010 for expr in &agg.group_expr {
1011 replace_map.insert(expr.schema_name().to_string(), expr.clone());
1012 }
1013 let replaced_push_predicates = push_predicates
1014 .into_iter()
1015 .map(|expr| replace_cols_by_name(expr, &replace_map))
1016 .collect::<Result<Vec<_>>>()?;
1017
1018 let agg_input = Arc::clone(&agg.input);
1019 Transformed::yes(LogicalPlan::Aggregate(agg))
1020 .transform_data(|new_plan| {
1021 if let Some(predicate) = conjunction(replaced_push_predicates) {
1023 let new_filter = make_filter(predicate, agg_input)?;
1024 insert_below(new_plan, new_filter)
1025 } else {
1026 Ok(Transformed::no(new_plan))
1027 }
1028 })?
1029 .map_data(|child_plan| {
1030 if let Some(predicate) = conjunction(keep_predicates) {
1033 make_filter(predicate, Arc::new(child_plan))
1034 } else {
1035 Ok(child_plan)
1036 }
1037 })
1038 }
1039 LogicalPlan::Window(window) => {
1050 let extract_partition_keys = |func: &WindowFunction| {
1056 func.params
1057 .partition_by
1058 .iter()
1059 .map(|c| {
1060 let (relation, name) = c.qualified_name();
1061 Column::new(relation, name)
1062 })
1063 .collect::<HashSet<_>>()
1064 };
1065 let potential_partition_keys = window
1066 .window_expr
1067 .iter()
1068 .map(|e| {
1069 match e {
1070 Expr::WindowFunction(window_func) => {
1071 extract_partition_keys(window_func)
1072 }
1073 Expr::Alias(alias) => {
1074 if let Expr::WindowFunction(window_func) =
1075 alias.expr.as_ref()
1076 {
1077 extract_partition_keys(window_func)
1078 } else {
1079 unreachable!()
1081 }
1082 }
1083 _ => {
1084 unreachable!()
1086 }
1087 }
1088 })
1089 .reduce(|a, b| &a & &b)
1092 .unwrap_or_default();
1093
1094 let predicates = split_conjunction_owned(filter.predicate);
1095 let mut keep_predicates = vec![];
1096 let mut push_predicates = vec![];
1097 for expr in predicates {
1098 let cols = expr.column_refs();
1099 if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1100 push_predicates.push(expr);
1101 } else {
1102 keep_predicates.push(expr);
1103 }
1104 }
1105
1106 let window_input = Arc::clone(&window.input);
1115 Transformed::yes(LogicalPlan::Window(window))
1116 .transform_data(|new_plan| {
1117 if let Some(predicate) = conjunction(push_predicates) {
1119 let new_filter = make_filter(predicate, window_input)?;
1120 insert_below(new_plan, new_filter)
1121 } else {
1122 Ok(Transformed::no(new_plan))
1123 }
1124 })?
1125 .map_data(|child_plan| {
1126 if let Some(predicate) = conjunction(keep_predicates) {
1129 make_filter(predicate, Arc::new(child_plan))
1130 } else {
1131 Ok(child_plan)
1132 }
1133 })
1134 }
1135 LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1136 LogicalPlan::TableScan(scan) => {
1137 let filter_predicates = split_conjunction(&filter.predicate);
1138
1139 let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1140 filter_predicates
1141 .into_iter()
1142 .partition(|pred| pred.is_volatile());
1143
1144 let supported_filters = scan
1146 .source
1147 .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1148 assert_eq_or_internal_err!(
1149 non_volatile_filters.len(),
1150 supported_filters.len(),
1151 "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1152 supported_filters.len(),
1153 non_volatile_filters.len()
1154 );
1155
1156 let zip = non_volatile_filters.into_iter().zip(supported_filters);
1158
1159 let new_scan_filters = zip
1160 .clone()
1161 .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1162 .map(|(pred, _)| pred);
1163
1164 let new_scan_filters: Vec<Expr> = scan
1166 .filters
1167 .iter()
1168 .chain(new_scan_filters)
1169 .unique()
1170 .cloned()
1171 .collect();
1172
1173 let new_predicate: Vec<Expr> = zip
1175 .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1176 .map(|(pred, _)| pred)
1177 .chain(volatile_filters)
1178 .cloned()
1179 .collect();
1180
1181 let new_scan = LogicalPlan::TableScan(TableScan {
1182 filters: new_scan_filters,
1183 ..scan
1184 });
1185
1186 Transformed::yes(new_scan).transform_data(|new_scan| {
1187 if let Some(predicate) = conjunction(new_predicate) {
1188 make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1189 } else {
1190 Ok(Transformed::no(new_scan))
1191 }
1192 })
1193 }
1194 LogicalPlan::Extension(extension_plan) => {
1195 if extension_plan.node.inputs().is_empty() {
1198 filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1199 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1200 }
1201 let prevent_cols =
1202 extension_plan.node.prevent_predicate_push_down_columns();
1203
1204 let predicate_push_or_keep = split_conjunction(&filter.predicate)
1208 .iter()
1209 .map(|expr| {
1210 let cols = expr.column_refs();
1211 if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1212 Ok(false) } else {
1214 Ok(true) }
1216 })
1217 .collect::<Result<Vec<_>>>()?;
1218
1219 if predicate_push_or_keep.iter().all(|&x| !x) {
1221 filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1222 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1223 }
1224
1225 let mut keep_predicates = vec![];
1227 let mut push_predicates = vec![];
1228 for (push, expr) in predicate_push_or_keep
1229 .into_iter()
1230 .zip(split_conjunction_owned(filter.predicate).into_iter())
1231 {
1232 if !push {
1233 keep_predicates.push(expr);
1234 } else {
1235 push_predicates.push(expr);
1236 }
1237 }
1238
1239 let new_children = match conjunction(push_predicates) {
1240 Some(predicate) => extension_plan
1241 .node
1242 .inputs()
1243 .into_iter()
1244 .map(|child| {
1245 Ok(LogicalPlan::Filter(Filter::try_new(
1246 predicate.clone(),
1247 Arc::new(child.clone()),
1248 )?))
1249 })
1250 .collect::<Result<Vec<_>>>()?,
1251 None => extension_plan.node.inputs().into_iter().cloned().collect(),
1252 };
1253 let child_plan = LogicalPlan::Extension(extension_plan);
1255 let new_extension =
1256 child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1257
1258 let new_plan = match conjunction(keep_predicates) {
1259 Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1260 predicate,
1261 Arc::new(new_extension),
1262 )?),
1263 None => new_extension,
1264 };
1265 Ok(Transformed::yes(new_plan))
1266 }
1267 child => {
1268 filter.input = Arc::new(child);
1269 Ok(Transformed::no(LogicalPlan::Filter(filter)))
1270 }
1271 }
1272 }
1273}
1274
1275fn rewrite_projection(
1303 predicates: Vec<Expr>,
1304 mut projection: Projection,
1305) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1306 let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
1313 .schema
1314 .iter()
1315 .zip(projection.expr.iter())
1316 .map(|((qualifier, field), expr)| {
1317 let expr = expr.clone().unalias();
1319
1320 (qualified_name(qualifier, field.name()), expr)
1321 })
1322 .partition(|(_, value)| {
1323 value.is_volatile()
1324 || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1325 });
1326
1327 let mut push_predicates = vec![];
1328 let mut keep_predicates = vec![];
1329 for expr in predicates {
1330 if contain(&expr, &non_pushable_map) {
1331 keep_predicates.push(expr);
1332 } else {
1333 push_predicates.push(expr);
1334 }
1335 }
1336
1337 match conjunction(push_predicates) {
1338 Some(expr) => {
1339 let new_filter = LogicalPlan::Filter(Filter::try_new(
1342 replace_cols_by_name(expr, &pushable_map)?,
1343 std::mem::take(&mut projection.input),
1344 )?);
1345
1346 projection.input = Arc::new(new_filter);
1347
1348 Ok((
1349 Transformed::yes(LogicalPlan::Projection(projection)),
1350 conjunction(keep_predicates),
1351 ))
1352 }
1353 None => Ok((
1354 Transformed::no(LogicalPlan::Projection(projection)),
1355 conjunction(keep_predicates),
1356 )),
1357 }
1358}
1359
1360pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1362 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1363}
1364
1365fn insert_below(
1379 plan: LogicalPlan,
1380 new_child: LogicalPlan,
1381) -> Result<Transformed<LogicalPlan>> {
1382 let mut new_child = Some(new_child);
1383 let transformed_plan = plan.map_children(|_child| {
1384 if let Some(new_child) = new_child.take() {
1385 Ok(Transformed::yes(new_child))
1386 } else {
1387 internal_err!("node had more than one input")
1389 }
1390 })?;
1391
1392 assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1394
1395 Ok(transformed_plan)
1396}
1397
1398impl PushDownFilter {
1399 #[expect(missing_docs)]
1400 pub fn new() -> Self {
1401 Self {}
1402 }
1403}
1404
1405pub fn replace_cols_by_name(
1407 e: Expr,
1408 replace_map: &HashMap<String, Expr>,
1409) -> Result<Expr> {
1410 e.transform_up(|expr| {
1411 Ok(if let Expr::Column(c) = &expr {
1412 match replace_map.get(&c.flat_name()) {
1413 Some(new_c) => Transformed::yes(new_c.clone()),
1414 None => Transformed::no(expr),
1415 }
1416 } else {
1417 Transformed::no(expr)
1418 })
1419 })
1420 .data()
1421}
1422
1423fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1425 let mut is_contain = false;
1426 e.apply(|expr| {
1427 Ok(if let Expr::Column(c) = &expr {
1428 match check_map.get(&c.flat_name()) {
1429 Some(_) => {
1430 is_contain = true;
1431 TreeNodeRecursion::Stop
1432 }
1433 None => TreeNodeRecursion::Continue,
1434 }
1435 } else {
1436 TreeNodeRecursion::Continue
1437 })
1438 })
1439 .unwrap();
1440 is_contain
1441}
1442
1443#[cfg(test)]
1444mod tests {
1445 use std::any::Any;
1446 use std::cmp::Ordering;
1447 use std::fmt::{Debug, Formatter};
1448
1449 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1450 use async_trait::async_trait;
1451
1452 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1453 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1454 use datafusion_expr::logical_plan::table_scan;
1455 use datafusion_expr::{
1456 ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1457 ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1458 UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1459 in_subquery, lit,
1460 };
1461
1462 use crate::OptimizerContext;
1463 use crate::assert_optimized_plan_eq_snapshot;
1464 use crate::optimizer::Optimizer;
1465 use crate::simplify_expressions::SimplifyExpressions;
1466 use crate::test::udfs::leaf_udf_expr;
1467 use crate::test::*;
1468 use datafusion_expr::test::function_stub::sum;
1469 use insta::assert_snapshot;
1470
1471 use super::*;
1472
1473 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1474
1475 macro_rules! assert_optimized_plan_equal {
1476 (
1477 $plan:expr,
1478 @ $expected:literal $(,)?
1479 ) => {{
1480 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1481 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1482 assert_optimized_plan_eq_snapshot!(
1483 optimizer_ctx,
1484 rules,
1485 $plan,
1486 @ $expected,
1487 )
1488 }};
1489 }
1490
/// Runs [`SimplifyExpressions`] and then [`PushDownFilter`] through an
/// [`Optimizer`] with a default [`OptimizerContext`], and snapshots the
/// optimized plan against `$expected`.
///
/// Optimizer errors are propagated with `?`, and the expansion ends in
/// `Ok::<(), DataFusionError>(())`, so the macro can be the tail expression
/// of a test returning `Result<()>`.
macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
    (
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rule_chain: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![
            Arc::new(SimplifyExpressions::new()),
            Arc::new(PushDownFilter::new()),
        ];
        let optimizer = Optimizer::with_rules(rule_chain);
        let optimized_plan =
            optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
        assert_snapshot!(optimized_plan, @ $expected);
        Ok::<(), DataFusionError>(())
    }};
}
1505
/// A filter above a plain column projection is pushed through the projection
/// and lands on the scan as a `full_filters` entry.
#[test]
fn filter_before_projection() -> Result<()> {
    let predicate = col("a").eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1522
/// A filter sitting above a `Limit` stays above it; the snapshot shows the
/// plan unchanged.
#[test]
fn filter_after_limit() -> Result<()> {
    let scan = test_table_scan()?;
    let limited = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(10))?;
    let plan = limited.filter(col("a").eq(lit(1i64)))?.build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a = Int64(1)
      Limit: skip=0, fetch=10
        Projection: test.a, test.b
          TableScan: test
    "
    )
}
1542
/// A constant predicate that references no columns is still pushed into the
/// table scan.
#[test]
fn filter_no_columns() -> Result<()> {
    let constant_predicate = lit(0i64).eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(constant_predicate)?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
    )
}
1554
/// A filter on `a` placed above two stacked projections ends up on the scan
/// beneath both of them.
#[test]
fn filter_jump_2_plans() -> Result<()> {
    let scan = test_table_scan()?;
    let inner = LogicalPlanBuilder::from(scan).project(vec![col("a"), col("b"), col("c")])?;
    let plan = inner
        .project(vec![col("c"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.c, test.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1573
/// A filter on the group-by key `a` is pushed below the aggregation into the
/// scan.
#[test]
fn filter_move_agg() -> Result<()> {
    let scan = test_table_scan()?;
    let aggregates = vec![sum(col("b")).alias("total_salary")];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("a")], aggregates)?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1590
/// Same scenario as `filter_move_agg`, but the columns have `$`-prefixed
/// names; the names must survive the push-down rewrite.
#[test]
fn filter_move_agg_special() -> Result<()> {
    let fields = vec![
        Field::new("$a", DataType::UInt32, false),
        Field::new("$b", DataType::UInt32, false),
        Field::new("$c", DataType::UInt32, false),
    ];
    let schema = Schema::new(fields);
    let scan = table_scan(Some("test"), &schema, None)?.build()?;

    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
        .filter(col("$a").gt(lit(10i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
      TableScan: test, full_filters=[test.$a > Int64(10)]
    "
    )
}
1614
/// The group-by key is the expression `b + a`, not `b` itself, so a filter on
/// plain `b` stays above the aggregation.
#[test]
fn filter_complex_group_by() -> Result<()> {
    let group_by = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b > Int64(10)
      Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
        TableScan: test
    "
    )
}
1631
/// A filter that names the group-by expression's output column
/// (`test.b + test.a`) is rewritten against the input and pushed into the
/// scan.
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
    let scan = test_table_scan()?;
    let group_by = vec![add(col("b"), col("a"))];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(group_by, vec![sum(col("a")), col("b")])?
        .filter(col("test.b + test.a").gt(lit(10i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
      TableScan: test, full_filters=[test.b + test.a > Int64(10)]
    "
    )
}
1646
/// A filter on an aggregate result (`sum(test.b)` aliased as `b`) depends on
/// the aggregation and therefore stays above it.
#[test]
fn filter_keep_agg() -> Result<()> {
    let scan = test_table_scan()?;
    let summed_b = sum(col("b")).alias("b");
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("a")], vec![summed_b])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: b > Int64(10)
      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
        TableScan: test
    "
    )
}
1664
/// With `rank()` partitioned by `a` and `b`, a filter on the partition column
/// `b` is pushed below the `WindowAggr` into the scan.
#[test]
fn filter_move_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // The test returns `Result`, so builder failures are propagated with `?`
    // rather than panicking through `unwrap()`.
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a"), col("b")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.b > Int64(10)]
    "
    )
}
1694
/// Same scenario as `filter_move_window`, but with `$`-prefixed column
/// names: the filter on partition column `$b` is pushed into the scan.
#[test]
fn filter_window_special_identifier() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("$a", DataType::UInt32, false),
        Field::new("$b", DataType::UInt32, false),
        Field::new("$c", DataType::UInt32, false),
    ]);
    let table_scan = table_scan(Some("test"), &schema, None)?.build()?;

    // Propagate builder errors with `?` instead of `unwrap()`.
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("$a"), col("$b")])
    .order_by(vec![col("$c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(col("$b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.$b > Int64(10)]
    "
    )
}
1729
/// With the window partitioned by `a` and `b`, a conjunction filtering on both
/// `a` and `b` is pushed entirely into the scan.
#[test]
fn filter_move_complex_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Propagate builder errors with `?` instead of `unwrap()`.
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a"), col("b")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
    "
    )
}
1760
/// With the window partitioned only by `a`, the conjunct on `a` is pushed
/// into the scan while the conjunct on `b` stays above the `WindowAggr`.
#[test]
fn filter_move_partial_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Propagate builder errors with `?` instead of `unwrap()`.
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b = Int64(1)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1791
/// Partitioning by the expression `a + b` does not let a filter on `a + b`
/// through: the snapshot keeps the filter above the `WindowAggr`.
#[test]
fn filter_expression_keep_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Propagate builder errors with `?` instead of `unwrap()`.
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![add(col("a"), col("b"))])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(add(col("a"), col("b")).gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a + test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1825
/// A filter on `c`, which the window uses only for ORDER BY (not for
/// partitioning), stays above the `WindowAggr`.
#[test]
fn filter_order_keep_window() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Propagate builder errors with `?` instead of `unwrap()`.
    let window = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(
            datafusion_functions_window::rank::rank_udwf(),
        ),
        vec![],
    ))
    .partition_by(vec![col("a")])
    .order_by(vec![col("c").sort(true, true)])
    .build()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window])?
        .filter(col("c").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1856
/// With two window functions that both partition by `a`, a filter on `a` is
/// pushed below the `WindowAggr` into the scan.
#[test]
fn filter_multiple_windows_common_partitions() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Builds `rank() PARTITION BY <partition_by> ORDER BY c ASC NULLS FIRST`,
    // propagating builder errors instead of unwrapping them.
    let rank_over = |partition_by: Vec<Expr>| -> Result<Expr> {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(partition_by)
        .order_by(vec![col("c").sort(true, true)])
        .build()
    };

    let window1 = rank_over(vec![col("a")])?;
    let window2 = rank_over(vec![col("b"), col("a")])?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window1, window2])?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1898
/// With two window functions where only one partitions by `b`, a filter on
/// `b` stays above the `WindowAggr`.
#[test]
fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Builds `rank() PARTITION BY <partition_by> ORDER BY c ASC NULLS FIRST`,
    // propagating builder errors instead of unwrapping them.
    let rank_over = |partition_by: Vec<Expr>| -> Result<Expr> {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(partition_by)
        .order_by(vec![col("c").sort(true, true)])
        .build()
    };

    let window1 = rank_over(vec![col("a")])?;
    let window2 = rank_over(vec![col("b"), col("a")])?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window1, window2])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1941
/// When the filtered column is a projection alias (`a AS b`), the predicate
/// is rewritten to the underlying column before it reaches the scan.
#[test]
fn alias() -> Result<()> {
    let aliased = vec![col("a").alias("b"), col("c")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(aliased)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1959
1960 fn add(left: Expr, right: Expr) -> Expr {
1961 Expr::BinaryExpr(BinaryExpr::new(
1962 Box::new(left),
1963 Operator::Plus,
1964 Box::new(right),
1965 ))
1966 }
1967
1968 fn multiply(left: Expr, right: Expr) -> Expr {
1969 Expr::BinaryExpr(BinaryExpr::new(
1970 Box::new(left),
1971 Operator::Multiply,
1972 Box::new(right),
1973 ))
1974 }
1975
/// A filter on the alias of a compound expression (`a * 2 + c AS b`) is
/// rewritten with that expression substituted and then pushed into the scan.
#[test]
fn complex_expression() -> Result<()> {
    let scan = test_table_scan()?;
    let b_expr = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![b_expr, col("c")])?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: b = Int64(1)
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a * Int32(2) + test.c AS b, test.c
      TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
    "
    )
}
2005
/// A filter pushed through two projections is rewritten through both
/// aliases (`a` is `b * 3` and `b` is `a * 2 + c`) before reaching the scan.
#[test]
fn complex_plan() -> Result<()> {
    let scan = test_table_scan()?;
    let inner_b = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let outer_a = multiply(col("b"), lit(3)).alias("a");
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![inner_b, col("c")])?
        .project(vec![outer_a, col("c")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: a = Int64(1)
      Projection: b * Int32(3) AS a, test.c
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b * Int32(3) AS a, test.c
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
    "
    )
}
2039
2040 #[derive(Debug, PartialEq, Eq, Hash)]
2041 struct NoopPlan {
2042 input: Vec<LogicalPlan>,
2043 schema: DFSchemaRef,
2044 }
2045
2046 impl PartialOrd for NoopPlan {
2048 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2049 self.input
2050 .partial_cmp(&other.input)
2051 .filter(|cmp| *cmp != Ordering::Equal || self == other)
2053 }
2054 }
2055
2056 impl UserDefinedLogicalNodeCore for NoopPlan {
2057 fn name(&self) -> &str {
2058 "NoopPlan"
2059 }
2060
2061 fn inputs(&self) -> Vec<&LogicalPlan> {
2062 self.input.iter().collect()
2063 }
2064
2065 fn schema(&self) -> &DFSchemaRef {
2066 &self.schema
2067 }
2068
2069 fn expressions(&self) -> Vec<Expr> {
2070 self.input
2071 .iter()
2072 .flat_map(|child| child.expressions())
2073 .collect()
2074 }
2075
2076 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2077 HashSet::from_iter(vec!["c".to_string()])
2078 }
2079
2080 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2081 write!(f, "NoopPlan")
2082 }
2083
2084 fn with_exprs_and_inputs(
2085 &self,
2086 _exprs: Vec<Expr>,
2087 inputs: Vec<LogicalPlan>,
2088 ) -> Result<Self> {
2089 Ok(Self {
2090 input: inputs,
2091 schema: Arc::clone(&self.schema),
2092 })
2093 }
2094
2095 fn supports_limit_pushdown(&self) -> bool {
2096 false }
2098 }
2099
/// Filters above a `NoopPlan` are pushed into every one of its inputs.
/// Predicates on `c`, which `NoopPlan` lists in
/// `prevent_predicate_push_down_columns`, stay above the node.
#[test]
fn user_defined_plan() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Wraps `inputs` in a `NoopPlan` extension node that reuses the scan's schema.
    let noop_over = |inputs: Vec<LogicalPlan>| {
        LogicalPlan::Extension(Extension {
            node: Arc::new(NoopPlan {
                input: inputs,
                schema: Arc::clone(table_scan.schema()),
            }),
        })
    };

    // One input, predicate on `a` only: pushed fully into the scan.
    let plan = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(col("a").eq(lit(1i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // One input, predicate on `a` and `c`: only the `a` conjunct is pushed.
    let plan = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate on `a`: pushed into both scans.
    let plan = LogicalPlanBuilder::from(noop_over(vec![
        table_scan.clone(),
        table_scan.clone(),
    ]))
    .filter(col("a").eq(lit(1i64)))?
    .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate on `a` and `c`: `a` goes to both scans, `c` stays.
    let plan = LogicalPlanBuilder::from(noop_over(vec![
        table_scan.clone(),
        table_scan.clone(),
    ]))
    .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
    .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2184
/// Two stacked filters above an aggregation: the one on the group-by key `b`
/// is pushed (rewritten through the `a AS b` alias), while the one on
/// `sum(test.c)` stays above the aggregation.
#[test]
fn multi_filter() -> Result<()> {
    let scan = test_table_scan()?;
    let aggregated = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?;
    let plan = aggregated
        .filter(col("b").gt(lit(10i64)))?
        .filter(col("sum(test.c)").gt(lit(10i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10)
      Filter: b > Int64(10)
        Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
          Projection: test.a AS b, test.c
            TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: sum(test.c) > Int64(10)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2219
/// A single conjunctive filter above an aggregation is split: the conjunct on
/// the group-by key `b` is pushed, and the conjuncts on `sum(test.c)` remain
/// above the aggregation.
#[test]
fn split_filter() -> Result<()> {
    let scan = test_table_scan()?;
    let predicate = and(
        col("sum(test.c)").gt(lit(10i64)),
        and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
    );
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(predicate)?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2255
/// With two stacked limits, the filter moves below the outermost projection
/// but crosses neither `Limit`.
#[test]
fn double_limit() -> Result<()> {
    let scan = test_table_scan()?;
    let limited = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(20))?
        .limit(0, Some(10))?;
    let plan = limited
        .project(vec![col("a"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: test.a = Int64(1)
        Limit: skip=0, fetch=10
          Limit: skip=0, fetch=20
            Projection: test.a, test.b
              TableScan: test
    "
    )
}
2280
/// A filter above a UNION is pushed into every branch, re-qualified for each
/// branch's table.
#[test]
fn union_all() -> Result<()> {
    let first = test_table_scan()?;
    let second = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;
    let plan = LogicalPlanBuilder::from(first)
        .union(second)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @r"
    Union
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test2, full_filters=[test2.a = Int64(1)]
    "
    )
}
2299
/// A filter above a UNION of two aliased projections is pushed through the
/// subquery alias and the `a AS b` projection in each branch.
#[test]
fn union_all_on_projection() -> Result<()> {
    let branch = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b")])?
        .alias("test2")?;

    let other_branch = branch.clone().build()?;
    let plan = branch
        .union(other_branch)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Union
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2327
/// Despite its name, this test builds a cross join, not a union. The
/// conjunction has one predicate per side (`test.a = 1`, `test1.d > 2`); it
/// is split and each half is pushed into its own side's scan.
#[test]
fn test_union_different_schema() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let right_schema = Schema::new(vec![
        Field::new("d", DataType::UInt32, false),
        Field::new("e", DataType::UInt32, false),
        Field::new("f", DataType::UInt32, false),
    ]);
    let right = table_scan(Some("test1"), &right_schema, None)?
        .project(vec![col("d"), col("e"), col("f")])?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.d")])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test1.d
      Cross Join:
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.d, test1.e, test1.f
          TableScan: test1, full_filters=[test1.d > Int32(2)]
    "
    )
}
2361
/// Both join sides expose a column named `a`; the qualifier (`test` vs
/// `test1`) alone decides which side each predicate is pushed to.
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
    let columns = || vec![col("a"), col("b"), col("c")];
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(columns())?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(columns())?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.a")])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test1.a
      Cross Join:
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.a, test1.b, test1.c
          TableScan: test1, full_filters=[test1.a > Int32(2)]
    "
    )
}
2391
/// Two filters separated by a `Limit`: the lower filter reaches the scan,
/// and the upper one moves below its projection but stops above the `Limit`.
#[test]
fn filter_2_breaks_limits() -> Result<()> {
    let scan = test_table_scan()?;
    let below_limit = LogicalPlanBuilder::from(scan)
        .project(vec![col("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?
        .limit(0, Some(1))?;
    let plan = below_limit
        .project(vec![col("a")])?
        .filter(col("a").gt_eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a >= Int64(1)
      Projection: test.a
        Limit: skip=0, fetch=1
          Filter: test.a <= Int64(1)
            Projection: test.a
              TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Limit: skip=0, fetch=1
          Projection: test.a
            TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2427
/// Two stacked filters that can only move as far as the `Limit` below them
/// are merged into one conjunctive filter sitting above that `Limit`.
#[test]
fn two_filters_on_same_depth() -> Result<()> {
    let scan = test_table_scan()?;
    let upper_bound = col("a").lt_eq(lit(1i64));
    let lower_bound = col("a").gt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(scan)
        .limit(0, Some(1))?
        .filter(upper_bound)?
        .filter(lower_bound)?
        .project(vec![col("a")])?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Filter: test.a <= Int64(1)
          Limit: skip=0, fetch=1
            TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1) AND test.a <= Int64(1)
        Limit: skip=0, fetch=1
          TableScan: test
    "
    )
}
2459
/// A filter below a `TestUserDefined` node is not lost: it is folded into the
/// scan underneath the node.
#[test]
fn filters_user_defined_node() -> Result<()> {
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;
    let plan = user_defined::new(filtered);

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    TestUserDefined
      Filter: test.a <= Int64(1)
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    TestUserDefined
      TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2487
/// A post-join filter on the equi-join key of an INNER join (`ON` form) is
/// pushed to both sides, re-qualified for the right side.
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2528
/// A post-join filter on the `USING` column of an INNER join is pushed to
/// both sides.
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_keys = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Inner, using_keys)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2568
/// A post-join predicate that references both sides (`test.c <= test2.b`)
/// cannot be pushed to either input; it becomes the INNER join's filter.
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let cross_side_predicate = col("c").lt_eq(col("b"));
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(cross_side_predicate)?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.c <= test2.b
      Inner Join: test.a = test2.a
        Projection: test.a, test.c
          TableScan: test
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.c <= test2.b
      Projection: test.a, test.c
        TableScan: test
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
2613
/// A post-join predicate that references only left-side columns (`test.b`)
/// is pushed to the left input only.
#[test]
fn filter_join_on_one_side() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("c")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.b <= Int64(1)
      Inner Join: test.a = test2.a
        Projection: test.a, test.b
          TableScan: test
        Projection: test2.a, test2.c
          TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b
        TableScan: test, full_filters=[test.b <= Int64(1)]
      Projection: test2.a, test2.c
        TableScan: test2
    "
    )
}
2658
/// LEFT join with `USING (a)` and a filter on the right-side column
/// `test2.a`. The original filter stays above the join. Because the `USING`
/// key equates the two columns, an equivalent `test.a <= 1` is also pushed
/// into the left scan; nothing is pushed to the right side.
#[test]
fn filter_using_left_join() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_keys = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Left, using_keys)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test, full_filters=[test.a <= Int64(1)]
        Projection: test2.a
          TableScan: test2
    "
    )
}
2700
/// Mirror of `filter_using_left_join`: in a RIGHT join the filter on the
/// left-side column `test.a` stays above the join, and an equivalent
/// predicate is pushed only into the right scan.
#[test]
fn filter_using_right_join() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_keys = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Right, using_keys)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // The plan as built, before any optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2741
#[test]
/// An unqualified filter on the USING column of a LEFT JOIN resolves to the
/// preserved side (`test.a`). It is removed from above the join and pushed
/// into both inputs' scans.
fn filter_using_left_join_on_common() -> Result<()> {
    let test_scan = test_table_scan()?;
    let left_input = LogicalPlanBuilder::from(test_scan).build()?;

    let test2_scan = test_table_scan_with_name("test2")?;
    let right_input = LogicalPlanBuilder::from(test2_scan)
        .project(vec![col("a")])?
        .build()?;

    let using_keys = vec![Column::from_name("a")];
    let predicate = col("a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Left, using_keys)?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: test.a <= Int64(1)
          Left Join: Using test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Left Join: Using test.a = test2.a
          TableScan: test, full_filters=[test.a <= Int64(1)]
          Projection: test2.a
            TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
    )
}
2781
#[test]
/// A filter on the preserved side's USING column (`test2.a`) of a RIGHT JOIN
/// is removed from above the join and pushed into both inputs' scans.
fn filter_using_right_join_on_common() -> Result<()> {
    let test_scan = test_table_scan()?;
    let left_input = LogicalPlanBuilder::from(test_scan).build()?;

    let test2_scan = test_table_scan_with_name("test2")?;
    let right_input = LogicalPlanBuilder::from(test2_scan)
        .project(vec![col("a")])?
        .build()?;

    let using_keys = vec![Column::from_name("a")];
    let predicate = col("test2.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Right, using_keys)?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: test2.a <= Int64(1)
          Right Join: Using test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Right Join: Using test.a = test2.a
          TableScan: test, full_filters=[test.a <= Int64(1)]
          Projection: test2.a
            TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
    )
}
2821
#[test]
/// For an INNER JOIN, single-side conjuncts of the ON filter are pushed into
/// the matching input. The cross-side conjunct stays in the join filter.
fn join_on_with_filter() -> Result<()> {
    // Both inputs project the same three columns of their base table.
    let project_abc = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()
    };
    let left_input = project_abc(test_table_scan()?)?;
    let right_input = project_abc(test_table_scan_with_name("test2")?)?;

    let left_only = col("test.c").gt(lit(1u32));
    let cross_side = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(cross_side).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Inner Join: test.a = test2.a Filter: test.b < test2.b
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.c > UInt32(1)]
          Projection: test2.a, test2.b, test2.c
            TableScan: test2, full_filters=[test2.c > UInt32(4)]
        "
    )
}
2867
#[test]
/// When every conjunct of an INNER JOIN's ON filter references only one side,
/// all of them are pushed down and the join filter disappears entirely.
fn join_filter_removed() -> Result<()> {
    // Both inputs project the same three columns of their base table.
    let project_abc = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()
    };
    let left_input = project_abc(test_table_scan()?)?;
    let right_input = project_abc(test_table_scan_with_name("test2")?)?;

    let left_only = col("test.b").gt(lit(1u32));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Inner Join: test.a = test2.a
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.b > UInt32(1)]
          Projection: test2.a, test2.b, test2.c
            TableScan: test2, full_filters=[test2.c > UInt32(4)]
        "
    )
}
2912
#[test]
/// An INNER JOIN filter on a join key (`test.a`) is pushed to its own side.
/// It is also rewritten through the equi-join key (`test.a = test2.b`) and
/// pushed to the other side.
fn join_filter_on_common() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;

    let join_filter = col("test.a").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
          Projection: test.a
            TableScan: test
          Projection: test2.b
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Inner Join: test.a = test2.b
          Projection: test.a
            TableScan: test, full_filters=[test.a > UInt32(1)]
          Projection: test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(1)]
        "
    )
}
2955
#[test]
/// For a LEFT JOIN, ON-filter conjuncts that touch only the right input are
/// pushed into it. The key-based predicate `test.a > 1` is also mirrored to
/// the right as `test2.a > 1`. Left-only and cross-side conjuncts stay in the
/// join filter.
fn left_join_on_with_filter() -> Result<()> {
    // Both inputs project the same three columns of their base table.
    let project_abc = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()
    };
    let left_input = project_abc(test_table_scan()?)?;
    let right_input = project_abc(test_table_scan_with_name("test2")?)?;

    let on_left_key = col("test.a").gt(lit(1u32));
    let cross_side = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = on_left_key.and(cross_side).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Left, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)]
        "
    )
}
3001
#[test]
/// For a RIGHT JOIN, only the left-only ON-filter conjunct is pushed into the
/// left input. Cross-side and right-only conjuncts stay in the join filter.
fn right_join_on_with_filter() -> Result<()> {
    // Both inputs project the same three columns of their base table.
    let project_abc = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()
    };
    let left_input = project_abc(test_table_scan()?)?;
    let right_input = project_abc(test_table_scan_with_name("test2")?)?;

    let left_only = col("test.a").gt(lit(1u32));
    let cross_side = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(cross_side).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Right, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.a > UInt32(1)]
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    )
}
3047
#[test]
/// A FULL JOIN preserves neither side, so no part of its ON filter is pushed
/// down; the optimized plan is identical to the input plan.
fn full_join_on_with_filter() -> Result<()> {
    // Both inputs project the same three columns of their base table.
    let project_abc = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()
    };
    let left_input = project_abc(test_table_scan()?)?;
    let right_input = project_abc(test_table_scan_with_name("test2")?)?;

    let left_only = col("test.a").gt(lit(1u32));
    let cross_side = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(cross_side).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Full, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
    )
}
3093
3094 struct PushDownProvider {
3095 pub filter_support: TableProviderFilterPushDown,
3096 }
3097
3098 #[async_trait]
3099 impl TableSource for PushDownProvider {
3100 fn schema(&self) -> SchemaRef {
3101 Arc::new(Schema::new(vec![
3102 Field::new("a", DataType::Int32, true),
3103 Field::new("b", DataType::Int32, true),
3104 ]))
3105 }
3106
3107 fn table_type(&self) -> TableType {
3108 TableType::Base
3109 }
3110
3111 fn supports_filters_pushdown(
3112 &self,
3113 filters: &[&Expr],
3114 ) -> Result<Vec<TableProviderFilterPushDown>> {
3115 Ok((0..filters.len())
3116 .map(|_| self.filter_support.clone())
3117 .collect())
3118 }
3119
3120 fn as_any(&self) -> &dyn Any {
3121 self
3122 }
3123 }
3124
3125 fn table_scan_with_pushdown_provider_builder(
3126 filter_support: TableProviderFilterPushDown,
3127 filters: Vec<Expr>,
3128 projection: Option<Vec<usize>>,
3129 ) -> Result<LogicalPlanBuilder> {
3130 let test_provider = PushDownProvider { filter_support };
3131
3132 let table_scan = LogicalPlan::TableScan(TableScan {
3133 table_name: "test".into(),
3134 filters,
3135 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3136 projection,
3137 source: Arc::new(test_provider),
3138 fetch: None,
3139 });
3140
3141 Ok(LogicalPlanBuilder::from(table_scan))
3142 }
3143
3144 fn table_scan_with_pushdown_provider(
3145 filter_support: TableProviderFilterPushDown,
3146 ) -> Result<LogicalPlan> {
3147 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3148 .filter(col("a").eq(lit(1i64)))?
3149 .build()
3150 }
3151
#[test]
/// With `Exact` support, the source fully handles the predicate, so the
/// `Filter` node is removed and the predicate becomes a full scan filter.
fn filter_with_table_provider_exact() -> Result<()> {
    let support = TableProviderFilterPushDown::Exact;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[a = Int64(1)]"
    )
}
3161
#[test]
/// With `Inexact` support, the predicate is handed to the scan as a partial
/// filter, and the `Filter` node is kept to re-check the rows.
fn filter_with_table_provider_inexact() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: a = Int64(1)
          TableScan: test, partial_filters=[a = Int64(1)]
        "
    )
}
3175
#[test]
/// Running the rule on an already-optimized `Inexact` plan must be stable:
/// the partial scan filter must not be duplicated.
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    // First pass, applied by hand.
    let rule = PushDownFilter::new();
    let context = OptimizerContext::new();
    let once_optimized = rule
        .rewrite(plan, &context)
        .expect("failed to optimize plan")
        .data;

    // The assertion macro optimizes its input again (as the other tests in this
    // module show), so this checks the result after a second pass.
    assert_optimized_plan_equal!(
        once_optimized,
        @r"
        Filter: a = Int64(1)
          TableScan: test, partial_filters=[a = Int64(1)]
        "
    )
}
3196
#[test]
/// With `Unsupported`, nothing is pushed into the scan and the `Filter` node
/// is left unchanged.
fn filter_with_table_provider_unsupported() -> Result<()> {
    let support = TableProviderFilterPushDown::Unsupported;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: a = Int64(1)
          TableScan: test
        "
    )
}
3210
#[test]
/// When the scan already carries the same partial filters as the `Filter`
/// above it, re-pushing with `Inexact` support does not duplicate them.
fn multi_combined_filter() -> Result<()> {
    let a_is_ten = col("a").eq(lit(10i64));
    let b_above_eleven = col("b").gt(lit(11i64));
    let existing_scan_filters = vec![a_is_ten.clone(), b_above_eleven.clone()];

    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Inexact,
        existing_scan_filters,
        Some(vec![0]),
    )?
    .filter(and(a_is_ten, b_above_eleven))?
    .project(vec![col("a"), col("b")])?
    .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: a, b
          Filter: a = Int64(10) AND b > Int64(11)
            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
        "
    )
}
3231
#[test]
/// With `Exact` support, both conjuncts of the filter become full scan
/// filters and the `Filter` node is removed.
fn multi_combined_filter_exact() -> Result<()> {
    let predicate = and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64)));

    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Exact,
        vec![],
        Some(vec![0]),
    )?
    .filter(predicate)?
    .project(vec![col("a"), col("b")])?
    .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: a, b
          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
        "
    )
}
3251
#[test]
/// A filter above a projection that aliases `test.a AS b` is rewritten in
/// terms of the underlying column and pushed into the scan.
fn test_filter_with_alias() -> Result<()> {
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: b > Int64(10) AND test.c > Int64(10)
          Projection: test.a AS b, test.c
            TableScan: test
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
        "
    )
}
3280
#[test]
/// As `test_filter_with_alias`, but the alias is reached through two stacked
/// projections. The predicate is rewritten through both and reaches the scan.
fn test_filter_with_alias_2() -> Result<()> {
    let scan = test_table_scan()?;
    let aliased = LogicalPlanBuilder::from(scan).project(vec![col("a").alias("b"), col("c")])?;
    let passthrough = aliased.project(vec![col("b"), col("c")])?;
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = passthrough.filter(predicate)?.build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: b > Int64(10) AND test.c > Int64(10)
          Projection: b, test.c
            Projection: test.a AS b, test.c
              TableScan: test
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: b, test.c
          Projection: test.a AS b, test.c
            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
        "
    )
}
3312
#[test]
/// A filter referencing two aliased columns (`b` for `test.a`, `d` for
/// `test.c`) has both aliases resolved before being pushed into the scan.
fn test_filter_with_multi_alias() -> Result<()> {
    let scan = test_table_scan()?;
    let renamed_columns = vec![col("a").alias("b"), col("c").alias("d")];
    let predicate = and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(renamed_columns)?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: b > Int64(10) AND d > Int64(10)
          Projection: test.a AS b, test.c AS d
            TableScan: test
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a AS b, test.c AS d
          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
        "
    )
}
3338
#[test]
/// A join filter on an aliased join key (`c`) is pushed through the alias on
/// its own side. Through the equi-join key `c = d`, it is also pushed to the
/// other side's underlying column.
fn join_filter_with_alias() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;

    let join_filter = col("c").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Inner Join: c = d Filter: c > UInt32(1)
          Projection: test.a AS c
            TableScan: test
          Projection: test2.b AS d
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Inner Join: c = d
          Projection: test.a AS c
            TableScan: test, full_filters=[test.a > UInt32(1)]
          Projection: test2.b AS d
            TableScan: test2, full_filters=[test2.b > UInt32(1)]
        "
    )
}
3381
#[test]
/// An `IN (list)` predicate on an aliased column is rewritten to the base
/// column and pushed into the scan.
fn test_in_filter_with_alias() -> Result<()> {
    let scan = test_table_scan()?;
    // UInt32(1) ..= UInt32(4)
    let candidates: Vec<Expr> = (1u32..=4).map(lit).collect();
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
          Projection: test.a AS b, test.c
            TableScan: test
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
        "
    )
}
3411
#[test]
/// As `test_in_filter_with_alias`, but through two stacked projections.
fn test_in_filter_with_alias_2() -> Result<()> {
    let scan = test_table_scan()?;
    // UInt32(1) ..= UInt32(4)
    let candidates: Vec<Expr> = (1u32..=4).map(lit).collect();
    let aliased = LogicalPlanBuilder::from(scan).project(vec![col("a").alias("b"), col("c")])?;
    let plan = aliased
        .project(vec![col("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
          Projection: b, test.c
            Projection: test.a AS b, test.c
              TableScan: test
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: b, test.c
          Projection: test.a AS b, test.c
            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
        "
    )
}
3444
#[test]
/// An `IN (<subquery>)` predicate on an aliased column is rewritten to the
/// base column and pushed into the scan, carrying its subquery with it.
fn test_in_subquery_with_alias() -> Result<()> {
    let outer_scan = test_table_scan()?;
    let sq_scan = test_table_scan_with_name("sq")?;
    let subquery_plan = LogicalPlanBuilder::from(sq_scan)
        .project(vec![col("c")])?
        .build()?;

    let predicate = in_subquery(col("b"), Arc::new(subquery_plan));
    let plan = LogicalPlanBuilder::from(outer_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: b IN (<subquery>)
          Subquery:
            Projection: sq.c
              TableScan: sq
          Projection: test.a AS b, test.c
            TableScan: test
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a IN (<subquery>)]
            Subquery:
              Projection: sq.c
                TableScan: sq
        "
    )
}
3484
#[test]
/// A filter on `b.a` propagates through two alias/projection layers. It is
/// rewritten to the literal it ultimately refers to and lands just above the
/// `EmptyRelation`.
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    // Innermost layer: a single row with constant column `a`, aliased as `b`.
    let inner = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?;
    // Second layer re-projects `b.a` and re-aliases it as `b`.
    let outer = inner.project(vec![col("b.a")])?.alias("b")?;
    let plan = outer
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Projection: b.a
          Filter: b.a = Int64(1)
            SubqueryAlias: b
              Projection: b.a
                SubqueryAlias: b
                  Projection: Int64(0) AS a
                    EmptyRelation: rows=1
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: b.a
          SubqueryAlias: b
            Projection: b.a
              SubqueryAlias: b
                Projection: Int64(0) AS a
                  Filter: Int64(0) = Int64(1)
                    EmptyRelation: rows=1
        "
    )
}
3523
#[test]
/// A disjunctive filter over a cross join becomes the join filter. With
/// predicate rewriting enabled, the left-only disjunction
/// `test.b > 1 OR test.c < 10` is also pushed into the left scan. Running the
/// plain rule again yields the same plan.
fn test_crossjoin_with_or_clause() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;

    let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
    let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
    let predicate = or(first_branch, second_branch);

    let plan = LogicalPlanBuilder::from(left_input)
        .cross_join(right_input)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
    Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    ")?;

    let rule = PushDownFilter::new();
    let context = OptimizerContext::new();
    let optimized_plan = rule
        .rewrite(plan, &context)
        .expect("failed to optimize plan")
        .data;
    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
        Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
          Projection: test1.a AS d, test1.a AS e
            TableScan: test1
        "
    )
}
3569
#[test]
/// A filter on the right join key above a LEFT SEMI join stays in place. The
/// equivalent predicate on the left key is pushed into the left scan.
fn left_semi_join() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftSemi, join_keys, None)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: test2.a <= Int64(1)
          LeftSemi Join: test1.a = test2.a
            TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test2.a <= Int64(1)
          LeftSemi Join: test1.a = test2.a
            TableScan: test1, full_filters=[test1.a <= Int64(1)]
            Projection: test2.a, test2.b
              TableScan: test2
        "
    )
}
3612
#[test]
/// Single-side conjuncts in a LEFT SEMI join's ON filter are pushed into their
/// respective inputs, leaving no join filter.
fn left_semi_join_with_filters() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let join_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftSemi, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        LeftSemi Join: test1.a = test2.a
          TableScan: test1, full_filters=[test1.b > UInt32(1)]
          Projection: test2.a, test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(2)]
        "
    )
}
3656
#[test]
/// A filter on the left join key above a RIGHT SEMI join stays in place. The
/// equivalent predicate on the right key is pushed into the right scan.
fn right_semi_join() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightSemi, join_keys, None)?
        .filter(col("test1.a").lt_eq(lit(1i64)))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: test1.a <= Int64(1)
          RightSemi Join: test1.a = test2.a
            TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test1.a <= Int64(1)
          RightSemi Join: test1.a = test2.a
            TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
    )
}
3699
#[test]
/// Single-side conjuncts in a RIGHT SEMI join's ON filter are pushed into
/// their respective inputs, leaving no join filter.
fn right_semi_join_with_filters() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let join_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightSemi, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        RightSemi Join: test1.a = test2.a
          TableScan: test1, full_filters=[test1.b > UInt32(1)]
          Projection: test2.a, test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(2)]
        "
    )
}
3743
#[test]
/// A filter on the right join key above a LEFT ANTI join stays in place. The
/// equivalent predicate on the left key is pushed into the left scan.
fn left_anti_join() -> Result<()> {
    // Both inputs project columns `a` and `b` of their base table.
    let project_ab = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b")])?
            .build()
    };
    let left_input = project_ab(test_table_scan_with_name("test1")?)?;
    let right_input = project_ab(test_table_scan_with_name("test2")?)?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftAnti, join_keys, None)?
        .filter(col("test2.a").gt(lit(2u32)))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: test2.a > UInt32(2)
          LeftAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test2.a > UInt32(2)
          LeftAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1, full_filters=[test1.a > UInt32(2)]
            Projection: test2.a, test2.b
              TableScan: test2
        "
    )
}
3791
#[test]
/// For a LEFT ANTI join, only the right-only ON-filter conjunct is pushed
/// down (into the right input). The left-only conjunct stays in the join
/// filter.
fn left_anti_join_with_filters() -> Result<()> {
    // Both inputs project columns `a` and `b` of their base table.
    let project_ab = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b")])?
            .build()
    };
    let left_input = project_ab(test_table_scan_with_name("test1")?)?;
    let right_input = project_ab(test_table_scan_with_name("test2")?)?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let join_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftAnti, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          Projection: test1.a, test1.b
            TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
          Projection: test1.a, test1.b
            TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(2)]
        "
    )
}
3840
#[test]
/// A filter on the left join key above a RIGHT ANTI join stays in place. The
/// equivalent predicate on the right key is pushed into the right scan.
fn right_anti_join() -> Result<()> {
    // Both inputs project columns `a` and `b` of their base table.
    let project_ab = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b")])?
            .build()
    };
    let left_input = project_ab(test_table_scan_with_name("test1")?)?;
    let right_input = project_ab(test_table_scan_with_name("test2")?)?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightAnti, join_keys, None)?
        .filter(col("test1.a").gt(lit(2u32)))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        Filter: test1.a > UInt32(2)
          RightAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        Filter: test1.a > UInt32(2)
          RightAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2, full_filters=[test2.a > UInt32(2)]
        "
    )
}
3888
#[test]
/// For a RIGHT ANTI join, only the left-only ON-filter conjunct is pushed
/// down (into the left input). The right-only conjunct stays in the join
/// filter.
fn right_anti_join_with_filters() -> Result<()> {
    // Both inputs project columns `a` and `b` of their base table.
    let project_ab = |input: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(input)
            .project(vec![col("a"), col("b")])?
            .build()
    };
    let left_input = project_ab(test_table_scan_with_name("test1")?)?;
    let right_input = project_ab(test_table_scan_with_name("test2")?)?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let join_filter = left_only.and(right_only);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightAnti, join_keys, Some(join_filter))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(
        plan,
        @r"
        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          Projection: test1.a, test1.b
            TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        "
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
          Projection: test1.a, test1.b
            TableScan: test1, full_filters=[test1.b > UInt32(1)]
          Projection: test2.a, test2.b
            TableScan: test2
        "
    )
}
3937
    /// Test-only scalar UDF whose sole field is the `Signature` it reports.
    ///
    /// Each test picks the function's volatility by the signature it builds it
    /// with; the tests in this module construct it with
    /// `Signature::exact(vec![], Volatility::Volatile)`, i.e. a zero-argument
    /// volatile function.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDF {
        // Returned verbatim from `ScalarUDFImpl::signature`.
        signature: Signature,
    }
3942
3943 impl ScalarUDFImpl for TestScalarUDF {
3944 fn as_any(&self) -> &dyn Any {
3945 self
3946 }
3947 fn name(&self) -> &str {
3948 "TestScalarUDF"
3949 }
3950
3951 fn signature(&self) -> &Signature {
3952 &self.signature
3953 }
3954
3955 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3956 Ok(DataType::Int32)
3957 }
3958
3959 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3960 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3961 }
3962 }
3963
3964 #[test]
3965 fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3966 let table_scan = test_table_scan_with_name("test1")?;
3968 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3969 signature: Signature::exact(vec![], Volatility::Volatile),
3970 });
3971 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3972
3973 let plan = LogicalPlanBuilder::from(table_scan)
3974 .aggregate(vec![col("a")], vec![sum(col("b"))])?
3975 .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3976 .alias("t")?
3977 .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3978 .project(vec![col("t.a"), col("t.r")])?
3979 .build()?;
3980
3981 assert_snapshot!(plan,
3982 @r"
3983 Projection: t.a, t.r
3984 Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3985 SubqueryAlias: t
3986 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3987 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3988 TableScan: test1
3989 ",
3990 );
3991 assert_optimized_plan_equal!(
3992 plan,
3993 @r"
3994 Projection: t.a, t.r
3995 SubqueryAlias: t
3996 Filter: r > Float64(0.5)
3997 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3998 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3999 TableScan: test1, full_filters=[test1.a > Int32(5)]
4000 "
4001 )
4002 }
4003
4004 #[test]
4005 fn test_push_down_volatile_function_in_join() -> Result<()> {
4006 let table_scan = test_table_scan_with_name("test1")?;
4008 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4009 signature: Signature::exact(vec![], Volatility::Volatile),
4010 });
4011 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4012 let left = LogicalPlanBuilder::from(table_scan).build()?;
4013 let right_table_scan = test_table_scan_with_name("test2")?;
4014 let right = LogicalPlanBuilder::from(right_table_scan).build()?;
4015 let plan = LogicalPlanBuilder::from(left)
4016 .join(
4017 right,
4018 JoinType::Inner,
4019 (
4020 vec![Column::from_qualified_name("test1.a")],
4021 vec![Column::from_qualified_name("test2.a")],
4022 ),
4023 None,
4024 )?
4025 .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4026 .alias("t")?
4027 .filter(col("t.r").gt(lit(0.8)))?
4028 .project(vec![col("t.a"), col("t.r")])?
4029 .build()?;
4030
4031 assert_snapshot!(plan,
4032 @r"
4033 Projection: t.a, t.r
4034 Filter: t.r > Float64(0.8)
4035 SubqueryAlias: t
4036 Projection: test1.a AS a, TestScalarUDF() AS r
4037 Inner Join: test1.a = test2.a
4038 TableScan: test1
4039 TableScan: test2
4040 ",
4041 );
4042 assert_optimized_plan_equal!(
4043 plan,
4044 @r"
4045 Projection: t.a, t.r
4046 SubqueryAlias: t
4047 Filter: r > Float64(0.8)
4048 Projection: test1.a AS a, TestScalarUDF() AS r
4049 Inner Join: test1.a = test2.a
4050 TableScan: test1
4051 TableScan: test2
4052 "
4053 )
4054 }
4055
4056 #[test]
4057 fn test_push_down_volatile_table_scan() -> Result<()> {
4058 let table_scan = test_table_scan()?;
4060 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4061 signature: Signature::exact(vec![], Volatility::Volatile),
4062 });
4063 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4064 let plan = LogicalPlanBuilder::from(table_scan)
4065 .project(vec![col("a"), col("b")])?
4066 .filter(expr.gt(lit(0.1)))?
4067 .build()?;
4068
4069 assert_snapshot!(plan,
4070 @r"
4071 Filter: TestScalarUDF() > Float64(0.1)
4072 Projection: test.a, test.b
4073 TableScan: test
4074 ",
4075 );
4076 assert_optimized_plan_equal!(
4077 plan,
4078 @r"
4079 Projection: test.a, test.b
4080 Filter: TestScalarUDF() > Float64(0.1)
4081 TableScan: test
4082 "
4083 )
4084 }
4085
4086 #[test]
4087 fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4088 let table_scan = test_table_scan()?;
4090 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4091 signature: Signature::exact(vec![], Volatility::Volatile),
4092 });
4093 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4094 let plan = LogicalPlanBuilder::from(table_scan)
4095 .project(vec![col("a"), col("b")])?
4096 .filter(
4097 expr.gt(lit(0.1))
4098 .and(col("t.a").gt(lit(5)))
4099 .and(col("t.b").gt(lit(10))),
4100 )?
4101 .build()?;
4102
4103 assert_snapshot!(plan,
4104 @r"
4105 Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4106 Projection: test.a, test.b
4107 TableScan: test
4108 ",
4109 );
4110 assert_optimized_plan_equal!(
4111 plan,
4112 @r"
4113 Projection: test.a, test.b
4114 Filter: TestScalarUDF() > Float64(0.1)
4115 TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4116 "
4117 )
4118 }
4119
4120 #[test]
4121 fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4122 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4124 signature: Signature::exact(vec![], Volatility::Volatile),
4125 });
4126 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4127 let plan = table_scan_with_pushdown_provider_builder(
4128 TableProviderFilterPushDown::Unsupported,
4129 vec![],
4130 None,
4131 )?
4132 .project(vec![col("a"), col("b")])?
4133 .filter(
4134 expr.gt(lit(0.1))
4135 .and(col("t.a").gt(lit(5)))
4136 .and(col("t.b").gt(lit(10))),
4137 )?
4138 .build()?;
4139
4140 assert_snapshot!(plan,
4141 @r"
4142 Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4143 Projection: a, b
4144 TableScan: test
4145 ",
4146 );
4147 assert_optimized_plan_equal!(
4148 plan,
4149 @r"
4150 Projection: a, b
4151 Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4152 TableScan: test
4153 "
4154 )
4155 }
4156
4157 #[test]
4158 fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4159 #[derive(Debug, Hash, Eq, PartialEq)]
4161 struct TestUserNode {
4162 schema: DFSchemaRef,
4163 }
4164
4165 impl PartialOrd for TestUserNode {
4166 fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4167 None
4168 }
4169 }
4170
4171 impl TestUserNode {
4172 fn new() -> Self {
4173 let schema = Arc::new(
4174 DFSchema::new_with_metadata(
4175 vec![(None, Field::new("a", DataType::Int64, false).into())],
4176 Default::default(),
4177 )
4178 .unwrap(),
4179 );
4180
4181 Self { schema }
4182 }
4183 }
4184
4185 impl UserDefinedLogicalNodeCore for TestUserNode {
4186 fn name(&self) -> &str {
4187 "test_node"
4188 }
4189
4190 fn inputs(&self) -> Vec<&LogicalPlan> {
4191 vec![]
4192 }
4193
4194 fn schema(&self) -> &DFSchemaRef {
4195 &self.schema
4196 }
4197
4198 fn expressions(&self) -> Vec<Expr> {
4199 vec![]
4200 }
4201
4202 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4203 write!(f, "TestUserNode")
4204 }
4205
4206 fn with_exprs_and_inputs(
4207 &self,
4208 exprs: Vec<Expr>,
4209 inputs: Vec<LogicalPlan>,
4210 ) -> Result<Self> {
4211 assert!(exprs.is_empty());
4212 assert!(inputs.is_empty());
4213 Ok(Self {
4214 schema: Arc::clone(&self.schema),
4215 })
4216 }
4217 }
4218
4219 let node = LogicalPlan::Extension(Extension {
4221 node: Arc::new(TestUserNode::new()),
4222 });
4223
4224 let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4225
4226 assert_snapshot!(plan,
4228 @r"
4229 Filter: Boolean(false)
4230 TestUserNode
4231 ",
4232 );
4233 assert_optimized_plan_equal!(
4235 plan,
4236 @r"
4237 Filter: Boolean(false)
4238 TestUserNode
4239 "
4240 )
4241 }
4242
4243 #[test]
4249 fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
4250 let table_scan = test_table_scan()?;
4251
4252 let proj = LogicalPlanBuilder::from(table_scan)
4254 .project(vec![
4255 leaf_udf_expr(col("a")).alias("val"),
4256 col("b"),
4257 col("c"),
4258 ])?
4259 .build()?;
4260
4261 let plan = LogicalPlanBuilder::from(proj)
4263 .filter(col("val").gt(lit(150i64)))?
4264 .build()?;
4265
4266 assert_optimized_plan_equal!(
4268 plan,
4269 @r"
4270 Filter: val > Int64(150)
4271 Projection: leaf_udf(test.a) AS val, test.b, test.c
4272 TableScan: test
4273 "
4274 )
4275 }
4276
4277 #[test]
4279 fn filter_mixed_predicates_partial_push() -> Result<()> {
4280 let table_scan = test_table_scan()?;
4281
4282 let proj = LogicalPlanBuilder::from(table_scan)
4284 .project(vec![
4285 leaf_udf_expr(col("a")).alias("val"),
4286 col("b"),
4287 col("c"),
4288 ])?
4289 .build()?;
4290
4291 let plan = LogicalPlanBuilder::from(proj)
4293 .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))?
4294 .build()?;
4295
4296 assert_optimized_plan_equal!(
4298 plan,
4299 @r"
4300 Filter: val > Int64(150)
4301 Projection: leaf_udf(test.a) AS val, test.b, test.c
4302 TableScan: test, full_filters=[test.b > Int64(5)]
4303 "
4304 )
4305 }
4306
4307 #[test]
4308 fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> {
4309 let scan = test_table_scan()?;
4310 let scan_with_fetch = match scan {
4311 LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan {
4312 fetch: Some(10),
4313 ..scan
4314 }),
4315 _ => unreachable!(),
4316 };
4317 let plan = LogicalPlanBuilder::from(scan_with_fetch)
4318 .filter(col("a").gt(lit(10i64)))?
4319 .build()?;
4320 assert_optimized_plan_equal!(
4322 plan,
4323 @r"
4324 Filter: test.a > Int64(10)
4325 TableScan: test, fetch=10
4326 "
4327 )
4328 }
4329
4330 #[test]
4331 fn filter_push_down_through_sort_without_fetch() -> Result<()> {
4332 let table_scan = test_table_scan()?;
4333 let plan = LogicalPlanBuilder::from(table_scan)
4334 .sort(vec![col("a").sort(true, true)])?
4335 .filter(col("a").gt(lit(10i64)))?
4336 .build()?;
4337 assert_optimized_plan_equal!(
4339 plan,
4340 @r"
4341 Sort: test.a ASC NULLS FIRST
4342 TableScan: test, full_filters=[test.a > Int64(10)]
4343 "
4344 )
4345 }
4346
4347 #[test]
4348 fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> {
4349 let table_scan = test_table_scan()?;
4350 let plan = LogicalPlanBuilder::from(table_scan)
4351 .sort_with_limit(vec![col("a").sort(true, true)], Some(5))?
4352 .filter(col("a").gt(lit(10i64)))?
4353 .build()?;
4354 assert_optimized_plan_equal!(
4357 plan,
4358 @r"
4359 Filter: test.a > Int64(10)
4360 Sort: test.a ASC NULLS FIRST, fetch=5
4361 TableScan: test
4362 "
4363 )
4364 }
4365}