1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26use log::{Level, debug, log_enabled};
27
28use datafusion_common::instant::Instant;
29use datafusion_common::tree_node::{
30 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
31};
32use datafusion_common::{
33 Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
34 internal_err, plan_err, qualified_name,
35};
36use datafusion_expr::expr::WindowFunction;
37use datafusion_expr::expr_rewriter::replace_col;
38use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
39use datafusion_expr::utils::{
40 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
41};
42use datafusion_expr::{
43 BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
44};
45
46use crate::optimizer::ApplyOrder;
47use crate::simplify_expressions::simplify_predicates;
48use crate::utils::{
49 ColumnReference, has_all_column_refs, is_restrict_null_predicate, schema_columns,
50};
51use crate::{OptimizerConfig, OptimizerRule};
52use datafusion_expr::ExpressionPlacement;
53
/// Optimizer rule that moves filter predicates as close to the data source as
/// it can.
///
/// The rule itself is stateless; everything it needs is taken from the plan
/// passed to [`OptimizerRule::rewrite`].
#[derive(Debug, Default)]
pub struct PushDownFilter {}
143
144pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
169 match join_type {
170 JoinType::Inner => (true, true),
171 JoinType::Left => (true, false),
172 JoinType::Right => (false, true),
173 JoinType::Full => (false, false),
174 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
177 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
180 }
181}
182
183pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
185 join_type.on_lr_is_preserved()
186}
187
188#[derive(Debug)]
191struct ColumnChecker<'a> {
192 left_schema: &'a DFSchema,
194 left_columns: Option<HashSet<ColumnReference<'a>>>,
196 right_schema: &'a DFSchema,
198 right_columns: Option<HashSet<ColumnReference<'a>>>,
200}
201
202impl<'a> ColumnChecker<'a> {
203 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
204 Self {
205 left_schema,
206 left_columns: None,
207 right_schema,
208 right_columns: None,
209 }
210 }
211
212 fn is_left_only(&mut self, predicate: &Expr) -> bool {
214 if self.left_columns.is_none() {
215 self.left_columns = Some(schema_columns(self.left_schema));
216 }
217 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
218 }
219
220 fn is_right_only(&mut self, predicate: &Expr) -> bool {
222 if self.right_columns.is_none() {
223 self.right_columns = Some(schema_columns(self.right_schema));
224 }
225 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
226 }
227}
228
229fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
231 let mut is_evaluate = true;
232 predicate.apply(|expr| match expr {
233 Expr::Column(_)
234 | Expr::Literal(_, _)
235 | Expr::Placeholder(_)
236 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
237 Expr::Exists { .. }
238 | Expr::InSubquery(_)
239 | Expr::SetComparison(_)
240 | Expr::ScalarSubquery(_)
241 | Expr::OuterReferenceColumn(_, _)
242 | Expr::Unnest(_) => {
243 is_evaluate = false;
244 Ok(TreeNodeRecursion::Stop)
245 }
246 Expr::Alias(_)
247 | Expr::BinaryExpr(_)
248 | Expr::Like(_)
249 | Expr::SimilarTo(_)
250 | Expr::Not(_)
251 | Expr::IsNotNull(_)
252 | Expr::IsNull(_)
253 | Expr::IsTrue(_)
254 | Expr::IsFalse(_)
255 | Expr::IsUnknown(_)
256 | Expr::IsNotTrue(_)
257 | Expr::IsNotFalse(_)
258 | Expr::IsNotUnknown(_)
259 | Expr::Negative(_)
260 | Expr::Between(_)
261 | Expr::Case(_)
262 | Expr::Cast(_)
263 | Expr::TryCast(_)
264 | Expr::InList { .. }
265 | Expr::ScalarFunction(_)
266 | Expr::HigherOrderFunction(_)
267 | Expr::Lambda(_)
268 | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue),
269 #[expect(deprecated)]
271 Expr::AggregateFunction(_)
272 | Expr::WindowFunction(_)
273 | Expr::Wildcard { .. }
274 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
275 })?;
276 Ok(is_evaluate)
277}
278
279fn extract_or_clauses_for_join<'a>(
313 filters: &'a [Expr],
314 schema_cols: &'a HashSet<ColumnReference>,
315) -> impl Iterator<Item = Expr> + 'a {
316 filters.iter().filter_map(move |expr| {
318 if let Expr::BinaryExpr(BinaryExpr {
319 left,
320 op: Operator::Or,
321 right,
322 }) = expr
323 {
324 let left_expr = extract_or_clause(left.as_ref(), schema_cols);
325 let right_expr = extract_or_clause(right.as_ref(), schema_cols);
326
327 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
329 return Some(or(left_expr, right_expr));
330 }
331 }
332 None
333 })
334}
335
336fn extract_or_clause(
348 expr: &Expr,
349 schema_columns: &HashSet<ColumnReference>,
350) -> Option<Expr> {
351 let mut predicate = None;
352
353 match expr {
354 Expr::BinaryExpr(BinaryExpr {
355 left: l_expr,
356 op: Operator::Or,
357 right: r_expr,
358 }) => {
359 let l_expr = extract_or_clause(l_expr, schema_columns);
360 let r_expr = extract_or_clause(r_expr, schema_columns);
361
362 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
363 predicate = Some(or(l_expr, r_expr));
364 }
365 }
366 Expr::BinaryExpr(BinaryExpr {
367 left: l_expr,
368 op: Operator::And,
369 right: r_expr,
370 }) => {
371 let l_expr = extract_or_clause(l_expr, schema_columns);
372 let r_expr = extract_or_clause(r_expr, schema_columns);
373
374 match (l_expr, r_expr) {
375 (Some(l_expr), Some(r_expr)) => {
376 predicate = Some(and(l_expr, r_expr));
377 }
378 (Some(l_expr), None) => {
379 predicate = Some(l_expr);
380 }
381 (None, Some(r_expr)) => {
382 predicate = Some(r_expr);
383 }
384 (None, None) => {
385 predicate = None;
386 }
387 }
388 }
389 _ => {
390 if has_all_column_refs(expr, schema_columns) {
391 predicate = Some(expr.clone());
392 }
393 }
394 }
395
396 predicate
397}
398
399fn push_down_all_join(
401 predicates: Vec<Expr>,
402 inferred_join_predicates: Vec<Expr>,
403 mut join: Join,
404 on_filter: Vec<Expr>,
405) -> Result<Transformed<LogicalPlan>> {
406 let is_inner_join = join.join_type == JoinType::Inner;
407 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
409
410 let left_schema = join.left.schema();
415 let right_schema = join.right.schema();
416
417 let left_schema_columns = schema_columns(left_schema.as_ref());
418 let right_schema_columns = schema_columns(right_schema.as_ref());
419
420 let mut left_push = vec![];
421 let mut right_push = vec![];
422 let mut keep_predicates = vec![];
423 let mut join_conditions = vec![];
424 let mut checker = ColumnChecker::new(left_schema, right_schema);
425 for predicate in predicates {
426 if left_preserved && checker.is_left_only(&predicate) {
427 left_push.push(predicate);
428 } else if right_preserved && checker.is_right_only(&predicate) {
429 right_push.push(predicate);
430 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
431 join_conditions.push(predicate);
434 } else {
435 keep_predicates.push(predicate);
436 }
437 }
438
439 for predicate in inferred_join_predicates {
441 if checker.is_left_only(&predicate) {
442 left_push.push(predicate);
443 } else if checker.is_right_only(&predicate) {
444 right_push.push(predicate);
445 }
446 }
447
448 let mut on_filter_join_conditions = vec![];
449 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
450
451 if !on_filter.is_empty() {
452 for on in on_filter {
453 if on_left_preserved && checker.is_left_only(&on) {
454 left_push.push(on)
455 } else if on_right_preserved && checker.is_right_only(&on) {
456 right_push.push(on)
457 } else {
458 on_filter_join_conditions.push(on)
459 }
460 }
461 }
462
463 if left_preserved {
466 left_push.extend(extract_or_clauses_for_join(
467 &keep_predicates,
468 &left_schema_columns,
469 ));
470 left_push.extend(extract_or_clauses_for_join(
471 &join_conditions,
472 &left_schema_columns,
473 ));
474 }
475 if right_preserved {
476 right_push.extend(extract_or_clauses_for_join(
477 &keep_predicates,
478 &right_schema_columns,
479 ));
480 right_push.extend(extract_or_clauses_for_join(
481 &join_conditions,
482 &right_schema_columns,
483 ));
484 }
485
486 if on_left_preserved {
489 left_push.extend(extract_or_clauses_for_join(
490 &on_filter_join_conditions,
491 &left_schema_columns,
492 ));
493 }
494 if on_right_preserved {
495 right_push.extend(extract_or_clauses_for_join(
496 &on_filter_join_conditions,
497 &right_schema_columns,
498 ));
499 }
500
501 if let Some(predicate) = conjunction(left_push) {
502 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
503 }
504 if let Some(predicate) = conjunction(right_push) {
505 join.right =
506 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
507 }
508
509 join_conditions.extend(on_filter_join_conditions);
511 join.filter = conjunction(join_conditions);
512
513 let plan = LogicalPlan::Join(join);
515 let plan = if let Some(predicate) = conjunction(keep_predicates) {
516 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
517 } else {
518 plan
519 };
520 Ok(Transformed::yes(plan))
521}
522
523fn push_down_join(
524 join: Join,
525 parent_predicate: Option<&Expr>,
526) -> Result<Transformed<LogicalPlan>> {
527 let predicates = parent_predicate
529 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
530
531 let on_filters = join
533 .filter
534 .as_ref()
535 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
536
537 let inferred_join_predicates = with_debug_timing("infer_join_predicates", || {
539 infer_join_predicates(&join, &predicates, &on_filters)
540 })?;
541
542 if log_enabled!(Level::Debug) {
543 debug!(
544 "push_down_filter: join_type={:?}, parent_predicates={}, on_filters={}, inferred_join_predicates={}",
545 join.join_type,
546 predicates.len(),
547 on_filters.len(),
548 inferred_join_predicates.len()
549 );
550 }
551
552 if on_filters.is_empty()
553 && predicates.is_empty()
554 && inferred_join_predicates.is_empty()
555 {
556 return Ok(Transformed::no(LogicalPlan::Join(join)));
557 }
558
559 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
560}
561
562fn infer_join_predicates(
572 join: &Join,
573 predicates: &[Expr],
574 on_filters: &[Expr],
575) -> Result<Vec<Expr>> {
576 let join_col_keys = join
578 .on
579 .iter()
580 .filter_map(|(l, r)| {
581 let left_col = l.try_as_col()?;
582 let right_col = r.try_as_col()?;
583 Some((left_col, right_col))
584 })
585 .collect::<Vec<_>>();
586
587 let join_type = join.join_type;
588
589 let mut inferred_predicates = InferredPredicates::new(join_type);
590
591 infer_join_predicates_from_predicates(
592 &join_col_keys,
593 predicates,
594 &mut inferred_predicates,
595 )?;
596
597 infer_join_predicates_from_on_filters(
598 &join_col_keys,
599 join_type,
600 on_filters,
601 &mut inferred_predicates,
602 )?;
603
604 Ok(inferred_predicates.predicates)
605}
606
607struct InferredPredicates {
616 predicates: Vec<Expr>,
617 is_inner_join: bool,
618}
619
620impl InferredPredicates {
621 fn new(join_type: JoinType) -> Self {
622 Self {
623 predicates: vec![],
624 is_inner_join: join_type == JoinType::Inner,
625 }
626 }
627
628 fn try_build_predicate(
629 &mut self,
630 predicate: Expr,
631 replace_map: &HashMap<&Column, &Column>,
632 ) -> Result<()> {
633 if self.is_inner_join
634 || matches!(
635 is_restrict_null_predicate(
636 predicate.clone(),
637 replace_map.keys().cloned()
638 ),
639 Ok(true)
640 )
641 {
642 self.predicates.push(replace_col(predicate, replace_map)?);
643 }
644
645 Ok(())
646 }
647}
648
649fn infer_join_predicates_from_predicates(
658 join_col_keys: &[(&Column, &Column)],
659 predicates: &[Expr],
660 inferred_predicates: &mut InferredPredicates,
661) -> Result<()> {
662 infer_join_predicates_impl::<true, true>(
663 join_col_keys,
664 predicates,
665 inferred_predicates,
666 )
667}
668
669fn infer_join_predicates_from_on_filters(
681 join_col_keys: &[(&Column, &Column)],
682 join_type: JoinType,
683 on_filters: &[Expr],
684 inferred_predicates: &mut InferredPredicates,
685) -> Result<()> {
686 match join_type {
687 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
688 JoinType::Inner => infer_join_predicates_impl::<true, true>(
689 join_col_keys,
690 on_filters,
691 inferred_predicates,
692 ),
693 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
694 infer_join_predicates_impl::<true, false>(
695 join_col_keys,
696 on_filters,
697 inferred_predicates,
698 )
699 }
700 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
701 infer_join_predicates_impl::<false, true>(
702 join_col_keys,
703 on_filters,
704 inferred_predicates,
705 )
706 }
707 }
708}
709
710fn infer_join_predicates_impl<
726 const ENABLE_LEFT_TO_RIGHT: bool,
727 const ENABLE_RIGHT_TO_LEFT: bool,
728>(
729 join_col_keys: &[(&Column, &Column)],
730 input_predicates: &[Expr],
731 inferred_predicates: &mut InferredPredicates,
732) -> Result<()> {
733 for predicate in input_predicates {
734 let mut join_cols_to_replace = HashMap::new();
735
736 for &col in &predicate.column_refs() {
737 for (l, r) in join_col_keys.iter() {
738 if ENABLE_LEFT_TO_RIGHT && col == *l {
739 join_cols_to_replace.insert(col, *r);
740 break;
741 }
742 if ENABLE_RIGHT_TO_LEFT && col == *r {
743 join_cols_to_replace.insert(col, *l);
744 break;
745 }
746 }
747 }
748 if join_cols_to_replace.is_empty() {
749 continue;
750 }
751
752 inferred_predicates
753 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
754 }
755 Ok(())
756}
757
758impl OptimizerRule for PushDownFilter {
759 fn name(&self) -> &str {
760 "push_down_filter"
761 }
762
763 fn apply_order(&self) -> Option<ApplyOrder> {
764 Some(ApplyOrder::TopDown)
765 }
766
767 fn supports_rewrite(&self) -> bool {
768 true
769 }
770
771 fn rewrite(
772 &self,
773 plan: LogicalPlan,
774 config: &dyn OptimizerConfig,
775 ) -> Result<Transformed<LogicalPlan>> {
776 let _ = config.options();
777 if let LogicalPlan::Join(join) = plan {
778 return push_down_join(join, None);
779 };
780
781 let plan_schema = Arc::clone(plan.schema());
782
783 let LogicalPlan::Filter(mut filter) = plan else {
784 return Ok(Transformed::no(plan));
785 };
786
787 let predicate = split_conjunction_owned(filter.predicate.clone());
788 let old_predicate_len = predicate.len();
789 let new_predicates =
790 with_debug_timing("simplify_predicates", || simplify_predicates(predicate))?;
791 if log_enabled!(Level::Debug) {
792 debug!(
793 "push_down_filter: simplify_predicates old_count={}, new_count={}",
794 old_predicate_len,
795 new_predicates.len()
796 );
797 }
798 if old_predicate_len != new_predicates.len() {
799 let Some(new_predicate) = conjunction(new_predicates) else {
800 return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
803 };
804 filter.predicate = new_predicate;
805 }
806
807 if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() {
811 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
812 }
813
814 match Arc::unwrap_or_clone(filter.input) {
815 LogicalPlan::Filter(child_filter) => {
816 let new_predicates = split_conjunction_owned(child_filter.predicate)
818 .into_iter()
819 .chain(split_conjunction_owned(filter.predicate))
820 .collect::<IndexSet<_>>();
822
823 let Some(new_predicate) = conjunction(new_predicates) else {
824 return plan_err!("at least one expression exists");
825 };
826
827 let new_filter = LogicalPlan::Filter(Filter::try_new(
828 new_predicate,
829 child_filter.input,
830 )?);
831
832 self.rewrite(new_filter, config)
833 }
834 LogicalPlan::Repartition(repartition) => {
835 let new_filter =
836 Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
837 .map(LogicalPlan::Filter)?;
838 insert_below(LogicalPlan::Repartition(repartition), new_filter)
839 }
840 LogicalPlan::Distinct(distinct) => {
841 let new_filter =
842 Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
843 .map(LogicalPlan::Filter)?;
844 insert_below(LogicalPlan::Distinct(distinct), new_filter)
845 }
846 LogicalPlan::Sort(sort) => {
847 let new_filter =
848 Filter::try_new(filter.predicate, Arc::clone(&sort.input))
849 .map(LogicalPlan::Filter)?;
850 insert_below(LogicalPlan::Sort(sort), new_filter)
851 }
852 LogicalPlan::SubqueryAlias(subquery_alias) => {
853 let mut replace_map = HashMap::new();
854 for (i, (qualifier, field)) in
855 subquery_alias.input.schema().iter().enumerate()
856 {
857 let (sub_qualifier, sub_field) =
858 subquery_alias.schema.qualified_field(i);
859 replace_map.insert(
860 qualified_name(sub_qualifier, sub_field.name()),
861 Expr::Column(Column::new(qualifier.cloned(), field.name())),
862 );
863 }
864 let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
865
866 let new_filter = LogicalPlan::Filter(Filter::try_new(
867 new_predicate,
868 Arc::clone(&subquery_alias.input),
869 )?);
870 insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
871 }
872 LogicalPlan::Projection(projection) => {
873 let predicates = split_conjunction_owned(filter.predicate.clone());
874 let (new_projection, keep_predicate) =
875 rewrite_projection(predicates, projection)?;
876 if new_projection.transformed {
877 match keep_predicate {
878 None => Ok(new_projection),
879 Some(keep_predicate) => new_projection.map_data(|child_plan| {
880 Filter::try_new(keep_predicate, Arc::new(child_plan))
881 .map(LogicalPlan::Filter)
882 }),
883 }
884 } else {
885 filter.input = Arc::new(new_projection.data);
886 Ok(Transformed::no(LogicalPlan::Filter(filter)))
887 }
888 }
889 LogicalPlan::Unnest(mut unnest) => {
890 let predicates = split_conjunction_owned(filter.predicate.clone());
891 let mut non_unnest_predicates = vec![];
892 let mut unnest_predicates = vec![];
893 let mut unnest_struct_columns = vec![];
894
895 for idx in &unnest.struct_type_columns {
896 let (sub_qualifier, field) =
897 unnest.input.schema().qualified_field(*idx);
898 let field_name = field.name().clone();
899
900 if let DataType::Struct(children) = field.data_type() {
901 for child in children {
902 let child_name = child.name().clone();
903 unnest_struct_columns.push(Column::new(
904 sub_qualifier.cloned(),
905 format!("{field_name}.{child_name}"),
906 ));
907 }
908 }
909 }
910
911 for predicate in predicates {
912 let mut accum: HashSet<Column> = HashSet::new();
914 expr_to_columns(&predicate, &mut accum)?;
915
916 let contains_list_columns =
917 unnest.list_type_columns.iter().any(|(_, unnest_list)| {
918 accum.contains(&unnest_list.output_column)
919 });
920 let contains_struct_columns =
921 unnest_struct_columns.iter().any(|c| accum.contains(c));
922
923 if contains_list_columns || contains_struct_columns {
924 unnest_predicates.push(predicate);
925 } else {
926 non_unnest_predicates.push(predicate);
927 }
928 }
929
930 if non_unnest_predicates.is_empty() {
933 filter.input = Arc::new(LogicalPlan::Unnest(unnest));
934 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
935 }
936
937 let unnest_input = std::mem::take(&mut unnest.input);
946
947 let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
948 conjunction(non_unnest_predicates).unwrap(), unnest_input,
950 )?);
951
952 let unnest_plan =
956 insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
957
958 match conjunction(unnest_predicates) {
959 None => Ok(unnest_plan),
960 Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
961 Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
962 ))),
963 }
964 }
965 LogicalPlan::Union(ref union) => {
966 let mut inputs = Vec::with_capacity(union.inputs.len());
967 for input in &union.inputs {
968 let mut replace_map = HashMap::new();
969 for (i, (qualifier, field)) in input.schema().iter().enumerate() {
970 let (union_qualifier, union_field) =
971 union.schema.qualified_field(i);
972 replace_map.insert(
973 qualified_name(union_qualifier, union_field.name()),
974 Expr::Column(Column::new(qualifier.cloned(), field.name())),
975 );
976 }
977
978 let push_predicate =
979 replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
980 inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
981 push_predicate,
982 Arc::clone(input),
983 )?)))
984 }
985 Ok(Transformed::yes(LogicalPlan::Union(Union {
986 inputs,
987 schema: Arc::clone(&plan_schema),
988 })))
989 }
990 LogicalPlan::Aggregate(agg) => {
991 let group_expr_columns = agg
993 .group_expr
994 .iter()
995 .map(|e| {
996 let (relation, name) = e.qualified_name();
997 Column::new(relation, name)
998 })
999 .collect::<HashSet<_>>();
1000
1001 let predicates = split_conjunction_owned(filter.predicate);
1002
1003 let mut keep_predicates = vec![];
1004 let mut push_predicates = vec![];
1005 for expr in predicates {
1006 let cols = expr.column_refs();
1007 if cols.iter().all(|c| group_expr_columns.contains(c)) {
1008 push_predicates.push(expr);
1009 } else {
1010 keep_predicates.push(expr);
1011 }
1012 }
1013
1014 let mut replace_map = HashMap::new();
1018 for expr in &agg.group_expr {
1019 replace_map.insert(expr.schema_name().to_string(), expr.clone());
1020 }
1021 let replaced_push_predicates = push_predicates
1022 .into_iter()
1023 .map(|expr| replace_cols_by_name(expr, &replace_map))
1024 .collect::<Result<Vec<_>>>()?;
1025
1026 let agg_input = Arc::clone(&agg.input);
1027 Transformed::yes(LogicalPlan::Aggregate(agg))
1028 .transform_data(|new_plan| {
1029 if let Some(predicate) = conjunction(replaced_push_predicates) {
1031 let new_filter = make_filter(predicate, agg_input)?;
1032 insert_below(new_plan, new_filter)
1033 } else {
1034 Ok(Transformed::no(new_plan))
1035 }
1036 })?
1037 .map_data(|child_plan| {
1038 if let Some(predicate) = conjunction(keep_predicates) {
1041 make_filter(predicate, Arc::new(child_plan))
1042 } else {
1043 Ok(child_plan)
1044 }
1045 })
1046 }
1047 LogicalPlan::Window(window) => {
1058 let extract_partition_keys = |func: &WindowFunction| {
1064 func.params
1065 .partition_by
1066 .iter()
1067 .map(|c| {
1068 let (relation, name) = c.qualified_name();
1069 Column::new(relation, name)
1070 })
1071 .collect::<HashSet<_>>()
1072 };
1073 let potential_partition_keys = window
1074 .window_expr
1075 .iter()
1076 .map(|e| {
1077 match e {
1078 Expr::WindowFunction(window_func) => {
1079 extract_partition_keys(window_func)
1080 }
1081 Expr::Alias(alias) => {
1082 if let Expr::WindowFunction(window_func) =
1083 alias.expr.as_ref()
1084 {
1085 extract_partition_keys(window_func)
1086 } else {
1087 unreachable!()
1089 }
1090 }
1091 _ => {
1092 unreachable!()
1094 }
1095 }
1096 })
1097 .reduce(|a, b| &a & &b)
1100 .unwrap_or_default();
1101
1102 let predicates = split_conjunction_owned(filter.predicate);
1103 let mut keep_predicates = vec![];
1104 let mut push_predicates = vec![];
1105 for expr in predicates {
1106 let cols = expr.column_refs();
1107 if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1108 push_predicates.push(expr);
1109 } else {
1110 keep_predicates.push(expr);
1111 }
1112 }
1113
1114 let window_input = Arc::clone(&window.input);
1123 Transformed::yes(LogicalPlan::Window(window))
1124 .transform_data(|new_plan| {
1125 if let Some(predicate) = conjunction(push_predicates) {
1127 let new_filter = make_filter(predicate, window_input)?;
1128 insert_below(new_plan, new_filter)
1129 } else {
1130 Ok(Transformed::no(new_plan))
1131 }
1132 })?
1133 .map_data(|child_plan| {
1134 if let Some(predicate) = conjunction(keep_predicates) {
1137 make_filter(predicate, Arc::new(child_plan))
1138 } else {
1139 Ok(child_plan)
1140 }
1141 })
1142 }
1143 LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1144 LogicalPlan::TableScan(scan) => {
1145 let filter_predicates = split_conjunction(&filter.predicate);
1146
1147 let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) =
1151 filter_predicates
1152 .into_iter()
1153 .partition(|pred| pred.contains_scalar_subquery());
1154
1155 let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1156 pushdown_candidates
1157 .into_iter()
1158 .partition(|pred| pred.is_volatile());
1159
1160 let supported_filters = scan
1162 .source
1163 .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1164 assert_eq_or_internal_err!(
1165 non_volatile_filters.len(),
1166 supported_filters.len(),
1167 "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1168 supported_filters.len(),
1169 non_volatile_filters.len()
1170 );
1171
1172 let zip = non_volatile_filters.into_iter().zip(supported_filters);
1174
1175 let new_scan_filters = zip
1176 .clone()
1177 .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1178 .map(|(pred, _)| pred);
1179
1180 let new_scan_filters: Vec<Expr> = scan
1182 .filters
1183 .iter()
1184 .chain(new_scan_filters)
1185 .unique()
1186 .cloned()
1187 .collect();
1188
1189 let new_predicate: Vec<Expr> = zip
1192 .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1193 .map(|(pred, _)| pred)
1194 .chain(volatile_filters)
1195 .chain(subquery_filters)
1196 .cloned()
1197 .collect();
1198
1199 let new_scan = LogicalPlan::TableScan(TableScan {
1200 filters: new_scan_filters,
1201 ..scan
1202 });
1203
1204 Transformed::yes(new_scan).transform_data(|new_scan| {
1205 if let Some(predicate) = conjunction(new_predicate) {
1206 make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1207 } else {
1208 Ok(Transformed::no(new_scan))
1209 }
1210 })
1211 }
1212 LogicalPlan::Extension(extension_plan) => {
1213 if extension_plan.node.inputs().is_empty() {
1216 filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1217 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1218 }
1219 let prevent_cols =
1220 extension_plan.node.prevent_predicate_push_down_columns();
1221
1222 let predicate_push_or_keep = split_conjunction(&filter.predicate)
1226 .iter()
1227 .map(|expr| {
1228 let cols = expr.column_refs();
1229 if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1230 Ok(false) } else {
1232 Ok(true) }
1234 })
1235 .collect::<Result<Vec<_>>>()?;
1236
1237 if predicate_push_or_keep.iter().all(|&x| !x) {
1239 filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1240 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1241 }
1242
1243 let mut keep_predicates = vec![];
1245 let mut push_predicates = vec![];
1246 for (push, expr) in predicate_push_or_keep
1247 .into_iter()
1248 .zip(split_conjunction_owned(filter.predicate))
1249 {
1250 if !push {
1251 keep_predicates.push(expr);
1252 } else {
1253 push_predicates.push(expr);
1254 }
1255 }
1256
1257 let new_children = match conjunction(push_predicates) {
1258 Some(predicate) => extension_plan
1259 .node
1260 .inputs()
1261 .into_iter()
1262 .map(|child| {
1263 Ok(LogicalPlan::Filter(Filter::try_new(
1264 predicate.clone(),
1265 Arc::new(child.clone()),
1266 )?))
1267 })
1268 .collect::<Result<Vec<_>>>()?,
1269 None => extension_plan.node.inputs().into_iter().cloned().collect(),
1270 };
1271 let child_plan = LogicalPlan::Extension(extension_plan);
1273 let new_extension =
1274 child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1275
1276 let new_plan = match conjunction(keep_predicates) {
1277 Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1278 predicate,
1279 Arc::new(new_extension),
1280 )?),
1281 None => new_extension,
1282 };
1283 Ok(Transformed::yes(new_plan))
1284 }
1285 child => {
1286 filter.input = Arc::new(child);
1287 Ok(Transformed::no(LogicalPlan::Filter(filter)))
1288 }
1289 }
1290 }
1291}
1292
1293fn rewrite_projection(
1321 predicates: Vec<Expr>,
1322 mut projection: Projection,
1323) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1324 let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
1331 .schema
1332 .iter()
1333 .zip(projection.expr.iter())
1334 .map(|((qualifier, field), expr)| {
1335 let expr = expr.clone().unalias();
1337
1338 (qualified_name(qualifier, field.name()), expr)
1339 })
1340 .partition(|(_, value)| {
1341 value.is_volatile()
1342 || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1343 });
1344
1345 let mut push_predicates = vec![];
1346 let mut keep_predicates = vec![];
1347 for expr in predicates {
1348 if contain(&expr, &non_pushable_map) {
1349 keep_predicates.push(expr);
1350 } else {
1351 push_predicates.push(expr);
1352 }
1353 }
1354
1355 match conjunction(push_predicates) {
1356 Some(expr) => {
1357 let new_filter = LogicalPlan::Filter(Filter::try_new(
1360 replace_cols_by_name(expr, &pushable_map)?,
1361 std::mem::take(&mut projection.input),
1362 )?);
1363
1364 projection.input = Arc::new(new_filter);
1365
1366 Ok((
1367 Transformed::yes(LogicalPlan::Projection(projection)),
1368 conjunction(keep_predicates),
1369 ))
1370 }
1371 None => Ok((
1372 Transformed::no(LogicalPlan::Projection(projection)),
1373 conjunction(keep_predicates),
1374 )),
1375 }
1376}
1377
1378pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1380 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1381}
1382
1383fn insert_below(
1397 plan: LogicalPlan,
1398 new_child: LogicalPlan,
1399) -> Result<Transformed<LogicalPlan>> {
1400 let mut new_child = Some(new_child);
1401 let transformed_plan = plan.map_children(|_child| {
1402 if let Some(new_child) = new_child.take() {
1403 Ok(Transformed::yes(new_child))
1404 } else {
1405 internal_err!("node had more than one input")
1407 }
1408 })?;
1409
1410 assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1412
1413 Ok(transformed_plan)
1414}
1415
1416impl PushDownFilter {
1417 #[expect(missing_docs)]
1418 pub fn new() -> Self {
1419 Self {}
1420 }
1421}
1422
1423fn with_debug_timing<T, F>(label: &'static str, f: F) -> Result<T>
1424where
1425 F: FnOnce() -> Result<T>,
1426{
1427 if !log_enabled!(Level::Debug) {
1428 return f();
1429 }
1430 let start = Instant::now();
1431 let result = f();
1432 debug!(
1433 "push_down_filter_timing: section={label}, elapsed_us={}",
1434 start.elapsed().as_micros()
1435 );
1436 result
1437}
1438
1439pub fn replace_cols_by_name(
1441 e: Expr,
1442 replace_map: &HashMap<String, Expr>,
1443) -> Result<Expr> {
1444 e.transform_up(|expr| {
1445 Ok(if let Expr::Column(c) = &expr {
1446 match replace_map.get(&c.flat_name()) {
1447 Some(new_c) => Transformed::yes(new_c.clone()),
1448 None => Transformed::no(expr),
1449 }
1450 } else {
1451 Transformed::no(expr)
1452 })
1453 })
1454 .data()
1455}
1456
1457fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1459 let mut is_contain = false;
1460 e.apply(|expr| {
1461 Ok(if let Expr::Column(c) = &expr {
1462 match check_map.get(&c.flat_name()) {
1463 Some(_) => {
1464 is_contain = true;
1465 TreeNodeRecursion::Stop
1466 }
1467 None => TreeNodeRecursion::Continue,
1468 }
1469 } else {
1470 TreeNodeRecursion::Continue
1471 })
1472 })
1473 .unwrap();
1474 is_contain
1475}
1476
1477#[cfg(test)]
1478mod tests {
1479 use std::cmp::Ordering;
1480 use std::fmt::{Debug, Formatter};
1481
1482 use arrow::datatypes::{Field, Schema, SchemaRef};
1483 use async_trait::async_trait;
1484
1485 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1486 use datafusion_expr::expr::ScalarFunction;
1487 use datafusion_expr::logical_plan::table_scan;
1488 use datafusion_expr::{
1489 ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1490 ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1491 UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1492 in_subquery, lit,
1493 };
1494
1495 use crate::OptimizerContext;
1496 use crate::assert_optimized_plan_eq_snapshot;
1497 use crate::optimizer::Optimizer;
1498 use crate::simplify_expressions::SimplifyExpressions;
1499 use crate::test::udfs::leaf_udf_expr;
1500 use crate::test::*;
1501 use datafusion_expr::test::function_stub::sum;
1502 use insta::assert_snapshot;
1503
1504 use super::*;
1505
    /// No-op optimizer observer: ignores both the plan and the rule it is
    /// handed. Passed to `Optimizer::optimize` by the test macros below.
    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1507
1508 macro_rules! assert_optimized_plan_equal {
1509 (
1510 $plan:expr,
1511 @ $expected:literal $(,)?
1512 ) => {{
1513 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1514 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1515 assert_optimized_plan_eq_snapshot!(
1516 optimizer_ctx,
1517 rules,
1518 $plan,
1519 @ $expected,
1520 )
1521 }};
1522 }
1523
1524 macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1525 (
1526 $plan:expr,
1527 @ $expected:literal $(,)?
1528 ) => {{
1529 let optimizer = Optimizer::with_rules(vec![
1530 Arc::new(SimplifyExpressions::new()),
1531 Arc::new(PushDownFilter::new()),
1532 ]);
1533 let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1534 assert_snapshot!(optimized_plan, @ $expected);
1535 Ok::<(), DataFusionError>(())
1536 }};
1537 }
1538
1539 #[test]
1540 fn filter_before_projection() -> Result<()> {
1541 let table_scan = test_table_scan()?;
1542 let plan = LogicalPlanBuilder::from(table_scan)
1543 .project(vec![col("a"), col("b")])?
1544 .filter(col("a").eq(lit(1i64)))?
1545 .build()?;
1546 assert_optimized_plan_equal!(
1548 plan,
1549 @r"
1550 Projection: test.a, test.b
1551 TableScan: test, full_filters=[test.a = Int64(1)]
1552 "
1553 )
1554 }
1555
1556 #[test]
1557 fn filter_after_limit() -> Result<()> {
1558 let table_scan = test_table_scan()?;
1559 let plan = LogicalPlanBuilder::from(table_scan)
1560 .project(vec![col("a"), col("b")])?
1561 .limit(0, Some(10))?
1562 .filter(col("a").eq(lit(1i64)))?
1563 .build()?;
1564 assert_optimized_plan_equal!(
1566 plan,
1567 @r"
1568 Filter: test.a = Int64(1)
1569 Limit: skip=0, fetch=10
1570 Projection: test.a, test.b
1571 TableScan: test
1572 "
1573 )
1574 }
1575
1576 #[test]
1577 fn filter_no_columns() -> Result<()> {
1578 let table_scan = test_table_scan()?;
1579 let plan = LogicalPlanBuilder::from(table_scan)
1580 .filter(lit(0i64).eq(lit(1i64)))?
1581 .build()?;
1582 assert_optimized_plan_equal!(
1583 plan,
1584 @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1585 )
1586 }
1587
1588 #[test]
1589 fn filter_jump_2_plans() -> Result<()> {
1590 let table_scan = test_table_scan()?;
1591 let plan = LogicalPlanBuilder::from(table_scan)
1592 .project(vec![col("a"), col("b"), col("c")])?
1593 .project(vec![col("c"), col("b")])?
1594 .filter(col("a").eq(lit(1i64)))?
1595 .build()?;
1596 assert_optimized_plan_equal!(
1598 plan,
1599 @r"
1600 Projection: test.c, test.b
1601 Projection: test.a, test.b, test.c
1602 TableScan: test, full_filters=[test.a = Int64(1)]
1603 "
1604 )
1605 }
1606
1607 #[test]
1608 fn filter_move_agg() -> Result<()> {
1609 let table_scan = test_table_scan()?;
1610 let plan = LogicalPlanBuilder::from(table_scan)
1611 .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1612 .filter(col("a").gt(lit(10i64)))?
1613 .build()?;
1614 assert_optimized_plan_equal!(
1616 plan,
1617 @r"
1618 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1619 TableScan: test, full_filters=[test.a > Int64(10)]
1620 "
1621 )
1622 }
1623
1624 #[test]
1626 fn filter_move_agg_special() -> Result<()> {
1627 let schema = Schema::new(vec![
1628 Field::new("$a", DataType::UInt32, false),
1629 Field::new("$b", DataType::UInt32, false),
1630 Field::new("$c", DataType::UInt32, false),
1631 ]);
1632 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1633
1634 let plan = LogicalPlanBuilder::from(table_scan)
1635 .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1636 .filter(col("$a").gt(lit(10i64)))?
1637 .build()?;
1638 assert_optimized_plan_equal!(
1640 plan,
1641 @r"
1642 Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1643 TableScan: test, full_filters=[test.$a > Int64(10)]
1644 "
1645 )
1646 }
1647
1648 #[test]
1649 fn filter_complex_group_by() -> Result<()> {
1650 let table_scan = test_table_scan()?;
1651 let plan = LogicalPlanBuilder::from(table_scan)
1652 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1653 .filter(col("b").gt(lit(10i64)))?
1654 .build()?;
1655 assert_optimized_plan_equal!(
1656 plan,
1657 @r"
1658 Filter: test.b > Int64(10)
1659 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1660 TableScan: test
1661 "
1662 )
1663 }
1664
1665 #[test]
1666 fn push_agg_need_replace_expr() -> Result<()> {
1667 let plan = LogicalPlanBuilder::from(test_table_scan()?)
1668 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1669 .filter(col("test.b + test.a").gt(lit(10i64)))?
1670 .build()?;
1671 assert_optimized_plan_equal!(
1672 plan,
1673 @r"
1674 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1675 TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1676 "
1677 )
1678 }
1679
1680 #[test]
1681 fn filter_keep_agg() -> Result<()> {
1682 let table_scan = test_table_scan()?;
1683 let plan = LogicalPlanBuilder::from(table_scan)
1684 .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1685 .filter(col("b").gt(lit(10i64)))?
1686 .build()?;
1687 assert_optimized_plan_equal!(
1689 plan,
1690 @r"
1691 Filter: b > Int64(10)
1692 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1693 TableScan: test
1694 "
1695 )
1696 }
1697
1698 #[test]
1700 fn filter_move_window() -> Result<()> {
1701 let table_scan = test_table_scan()?;
1702
1703 let window = Expr::from(WindowFunction::new(
1704 WindowFunctionDefinition::WindowUDF(
1705 datafusion_functions_window::rank::rank_udwf(),
1706 ),
1707 vec![],
1708 ))
1709 .partition_by(vec![col("a"), col("b")])
1710 .order_by(vec![col("c").sort(true, true)])
1711 .build()
1712 .unwrap();
1713
1714 let plan = LogicalPlanBuilder::from(table_scan)
1715 .window(vec![window])?
1716 .filter(col("b").gt(lit(10i64)))?
1717 .build()?;
1718
1719 assert_optimized_plan_equal!(
1720 plan,
1721 @r"
1722 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1723 TableScan: test, full_filters=[test.b > Int64(10)]
1724 "
1725 )
1726 }
1727
1728 #[test]
1730 fn filter_window_special_identifier() -> Result<()> {
1731 let schema = Schema::new(vec![
1732 Field::new("$a", DataType::UInt32, false),
1733 Field::new("$b", DataType::UInt32, false),
1734 Field::new("$c", DataType::UInt32, false),
1735 ]);
1736 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1737
1738 let window = Expr::from(WindowFunction::new(
1739 WindowFunctionDefinition::WindowUDF(
1740 datafusion_functions_window::rank::rank_udwf(),
1741 ),
1742 vec![],
1743 ))
1744 .partition_by(vec![col("$a"), col("$b")])
1745 .order_by(vec![col("$c").sort(true, true)])
1746 .build()
1747 .unwrap();
1748
1749 let plan = LogicalPlanBuilder::from(table_scan)
1750 .window(vec![window])?
1751 .filter(col("$b").gt(lit(10i64)))?
1752 .build()?;
1753
1754 assert_optimized_plan_equal!(
1755 plan,
1756 @r"
1757 WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1758 TableScan: test, full_filters=[test.$b > Int64(10)]
1759 "
1760 )
1761 }
1762
1763 #[test]
1766 fn filter_move_complex_window() -> Result<()> {
1767 let table_scan = test_table_scan()?;
1768
1769 let window = Expr::from(WindowFunction::new(
1770 WindowFunctionDefinition::WindowUDF(
1771 datafusion_functions_window::rank::rank_udwf(),
1772 ),
1773 vec![],
1774 ))
1775 .partition_by(vec![col("a"), col("b")])
1776 .order_by(vec![col("c").sort(true, true)])
1777 .build()
1778 .unwrap();
1779
1780 let plan = LogicalPlanBuilder::from(table_scan)
1781 .window(vec![window])?
1782 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1783 .build()?;
1784
1785 assert_optimized_plan_equal!(
1786 plan,
1787 @r"
1788 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1789 TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1790 "
1791 )
1792 }
1793
1794 #[test]
1796 fn filter_move_partial_window() -> Result<()> {
1797 let table_scan = test_table_scan()?;
1798
1799 let window = Expr::from(WindowFunction::new(
1800 WindowFunctionDefinition::WindowUDF(
1801 datafusion_functions_window::rank::rank_udwf(),
1802 ),
1803 vec![],
1804 ))
1805 .partition_by(vec![col("a")])
1806 .order_by(vec![col("c").sort(true, true)])
1807 .build()
1808 .unwrap();
1809
1810 let plan = LogicalPlanBuilder::from(table_scan)
1811 .window(vec![window])?
1812 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1813 .build()?;
1814
1815 assert_optimized_plan_equal!(
1816 plan,
1817 @r"
1818 Filter: test.b = Int64(1)
1819 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1820 TableScan: test, full_filters=[test.a > Int64(10)]
1821 "
1822 )
1823 }
1824
1825 #[test]
1828 fn filter_expression_keep_window() -> Result<()> {
1829 let table_scan = test_table_scan()?;
1830
1831 let window = Expr::from(WindowFunction::new(
1832 WindowFunctionDefinition::WindowUDF(
1833 datafusion_functions_window::rank::rank_udwf(),
1834 ),
1835 vec![],
1836 ))
1837 .partition_by(vec![add(col("a"), col("b"))]) .order_by(vec![col("c").sort(true, true)])
1839 .build()
1840 .unwrap();
1841
1842 let plan = LogicalPlanBuilder::from(table_scan)
1843 .window(vec![window])?
1844 .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1847 .build()?;
1848
1849 assert_optimized_plan_equal!(
1850 plan,
1851 @r"
1852 Filter: test.a + test.b > Int64(10)
1853 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1854 TableScan: test
1855 "
1856 )
1857 }
1858
1859 #[test]
1861 fn filter_order_keep_window() -> Result<()> {
1862 let table_scan = test_table_scan()?;
1863
1864 let window = Expr::from(WindowFunction::new(
1865 WindowFunctionDefinition::WindowUDF(
1866 datafusion_functions_window::rank::rank_udwf(),
1867 ),
1868 vec![],
1869 ))
1870 .partition_by(vec![col("a")])
1871 .order_by(vec![col("c").sort(true, true)])
1872 .build()
1873 .unwrap();
1874
1875 let plan = LogicalPlanBuilder::from(table_scan)
1876 .window(vec![window])?
1877 .filter(col("c").gt(lit(10i64)))?
1878 .build()?;
1879
1880 assert_optimized_plan_equal!(
1881 plan,
1882 @r"
1883 Filter: test.c > Int64(10)
1884 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1885 TableScan: test
1886 "
1887 )
1888 }
1889
1890 #[test]
1893 fn filter_multiple_windows_common_partitions() -> Result<()> {
1894 let table_scan = test_table_scan()?;
1895
1896 let window1 = Expr::from(WindowFunction::new(
1897 WindowFunctionDefinition::WindowUDF(
1898 datafusion_functions_window::rank::rank_udwf(),
1899 ),
1900 vec![],
1901 ))
1902 .partition_by(vec![col("a")])
1903 .order_by(vec![col("c").sort(true, true)])
1904 .build()
1905 .unwrap();
1906
1907 let window2 = Expr::from(WindowFunction::new(
1908 WindowFunctionDefinition::WindowUDF(
1909 datafusion_functions_window::rank::rank_udwf(),
1910 ),
1911 vec![],
1912 ))
1913 .partition_by(vec![col("b"), col("a")])
1914 .order_by(vec![col("c").sort(true, true)])
1915 .build()
1916 .unwrap();
1917
1918 let plan = LogicalPlanBuilder::from(table_scan)
1919 .window(vec![window1, window2])?
1920 .filter(col("a").gt(lit(10i64)))? .build()?;
1922
1923 assert_optimized_plan_equal!(
1924 plan,
1925 @r"
1926 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1927 TableScan: test, full_filters=[test.a > Int64(10)]
1928 "
1929 )
1930 }
1931
1932 #[test]
1935 fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1936 let table_scan = test_table_scan()?;
1937
1938 let window1 = Expr::from(WindowFunction::new(
1939 WindowFunctionDefinition::WindowUDF(
1940 datafusion_functions_window::rank::rank_udwf(),
1941 ),
1942 vec![],
1943 ))
1944 .partition_by(vec![col("a")])
1945 .order_by(vec![col("c").sort(true, true)])
1946 .build()
1947 .unwrap();
1948
1949 let window2 = Expr::from(WindowFunction::new(
1950 WindowFunctionDefinition::WindowUDF(
1951 datafusion_functions_window::rank::rank_udwf(),
1952 ),
1953 vec![],
1954 ))
1955 .partition_by(vec![col("b"), col("a")])
1956 .order_by(vec![col("c").sort(true, true)])
1957 .build()
1958 .unwrap();
1959
1960 let plan = LogicalPlanBuilder::from(table_scan)
1961 .window(vec![window1, window2])?
1962 .filter(col("b").gt(lit(10i64)))? .build()?;
1964
1965 assert_optimized_plan_equal!(
1966 plan,
1967 @r"
1968 Filter: test.b > Int64(10)
1969 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1970 TableScan: test
1971 "
1972 )
1973 }
1974
1975 #[test]
1977 fn alias() -> Result<()> {
1978 let table_scan = test_table_scan()?;
1979 let plan = LogicalPlanBuilder::from(table_scan)
1980 .project(vec![col("a").alias("b"), col("c")])?
1981 .filter(col("b").eq(lit(1i64)))?
1982 .build()?;
1983 assert_optimized_plan_equal!(
1985 plan,
1986 @r"
1987 Projection: test.a AS b, test.c
1988 TableScan: test, full_filters=[test.a = Int64(1)]
1989 "
1990 )
1991 }
1992
1993 fn add(left: Expr, right: Expr) -> Expr {
1994 Expr::BinaryExpr(BinaryExpr::new(
1995 Box::new(left),
1996 Operator::Plus,
1997 Box::new(right),
1998 ))
1999 }
2000
2001 fn multiply(left: Expr, right: Expr) -> Expr {
2002 Expr::BinaryExpr(BinaryExpr::new(
2003 Box::new(left),
2004 Operator::Multiply,
2005 Box::new(right),
2006 ))
2007 }
2008
2009 #[test]
2011 fn complex_expression() -> Result<()> {
2012 let table_scan = test_table_scan()?;
2013 let plan = LogicalPlanBuilder::from(table_scan)
2014 .project(vec![
2015 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2016 col("c"),
2017 ])?
2018 .filter(col("b").eq(lit(1i64)))?
2019 .build()?;
2020
2021 assert_snapshot!(plan,
2023 @r"
2024 Filter: b = Int64(1)
2025 Projection: test.a * Int32(2) + test.c AS b, test.c
2026 TableScan: test
2027 ",
2028 );
2029 assert_optimized_plan_equal!(
2031 plan,
2032 @r"
2033 Projection: test.a * Int32(2) + test.c AS b, test.c
2034 TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
2035 "
2036 )
2037 }
2038
2039 #[test]
2041 fn complex_plan() -> Result<()> {
2042 let table_scan = test_table_scan()?;
2043 let plan = LogicalPlanBuilder::from(table_scan)
2044 .project(vec![
2045 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2046 col("c"),
2047 ])?
2048 .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2050 .filter(col("a").eq(lit(1i64)))?
2051 .build()?;
2052
2053 assert_snapshot!(plan,
2055 @r"
2056 Filter: a = Int64(1)
2057 Projection: b * Int32(3) AS a, test.c
2058 Projection: test.a * Int32(2) + test.c AS b, test.c
2059 TableScan: test
2060 ",
2061 );
2062 assert_optimized_plan_equal!(
2064 plan,
2065 @r"
2066 Projection: b * Int32(3) AS a, test.c
2067 Projection: test.a * Int32(2) + test.c AS b, test.c
2068 TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2069 "
2070 )
2071 }
2072
2073 #[derive(Debug, PartialEq, Eq, Hash)]
2074 struct NoopPlan {
2075 input: Vec<LogicalPlan>,
2076 schema: DFSchemaRef,
2077 }
2078
2079 impl PartialOrd for NoopPlan {
2081 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2082 self.input
2083 .partial_cmp(&other.input)
2084 .filter(|cmp| *cmp != Ordering::Equal || self == other)
2086 }
2087 }
2088
2089 impl UserDefinedLogicalNodeCore for NoopPlan {
2090 fn name(&self) -> &str {
2091 "NoopPlan"
2092 }
2093
2094 fn inputs(&self) -> Vec<&LogicalPlan> {
2095 self.input.iter().collect()
2096 }
2097
2098 fn schema(&self) -> &DFSchemaRef {
2099 &self.schema
2100 }
2101
2102 fn expressions(&self) -> Vec<Expr> {
2103 self.input
2104 .iter()
2105 .flat_map(|child| child.expressions())
2106 .collect()
2107 }
2108
2109 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2110 HashSet::from_iter(vec!["c".to_string()])
2111 }
2112
2113 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2114 write!(f, "NoopPlan")
2115 }
2116
2117 fn with_exprs_and_inputs(
2118 &self,
2119 _exprs: Vec<Expr>,
2120 inputs: Vec<LogicalPlan>,
2121 ) -> Result<Self> {
2122 Ok(Self {
2123 input: inputs,
2124 schema: Arc::clone(&self.schema),
2125 })
2126 }
2127
2128 fn supports_limit_pushdown(&self) -> bool {
2129 false }
2131 }
2132
2133 #[test]
2134 fn user_defined_plan() -> Result<()> {
2135 let table_scan = test_table_scan()?;
2136
2137 let custom_plan = LogicalPlan::Extension(Extension {
2138 node: Arc::new(NoopPlan {
2139 input: vec![table_scan.clone()],
2140 schema: Arc::clone(table_scan.schema()),
2141 }),
2142 });
2143 let plan = LogicalPlanBuilder::from(custom_plan)
2144 .filter(col("a").eq(lit(1i64)))?
2145 .build()?;
2146
2147 assert_optimized_plan_equal!(
2149 plan,
2150 @r"
2151 NoopPlan
2152 TableScan: test, full_filters=[test.a = Int64(1)]
2153 "
2154 )?;
2155
2156 let custom_plan = LogicalPlan::Extension(Extension {
2157 node: Arc::new(NoopPlan {
2158 input: vec![table_scan.clone()],
2159 schema: Arc::clone(table_scan.schema()),
2160 }),
2161 });
2162 let plan = LogicalPlanBuilder::from(custom_plan)
2163 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2164 .build()?;
2165
2166 assert_optimized_plan_equal!(
2168 plan,
2169 @r"
2170 Filter: test.c = Int64(2)
2171 NoopPlan
2172 TableScan: test, full_filters=[test.a = Int64(1)]
2173 "
2174 )?;
2175
2176 let custom_plan = LogicalPlan::Extension(Extension {
2177 node: Arc::new(NoopPlan {
2178 input: vec![table_scan.clone(), table_scan.clone()],
2179 schema: Arc::clone(table_scan.schema()),
2180 }),
2181 });
2182 let plan = LogicalPlanBuilder::from(custom_plan)
2183 .filter(col("a").eq(lit(1i64)))?
2184 .build()?;
2185
2186 assert_optimized_plan_equal!(
2188 plan,
2189 @r"
2190 NoopPlan
2191 TableScan: test, full_filters=[test.a = Int64(1)]
2192 TableScan: test, full_filters=[test.a = Int64(1)]
2193 "
2194 )?;
2195
2196 let custom_plan = LogicalPlan::Extension(Extension {
2197 node: Arc::new(NoopPlan {
2198 input: vec![table_scan.clone(), table_scan.clone()],
2199 schema: Arc::clone(table_scan.schema()),
2200 }),
2201 });
2202 let plan = LogicalPlanBuilder::from(custom_plan)
2203 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2204 .build()?;
2205
2206 assert_optimized_plan_equal!(
2208 plan,
2209 @r"
2210 Filter: test.c = Int64(2)
2211 NoopPlan
2212 TableScan: test, full_filters=[test.a = Int64(1)]
2213 TableScan: test, full_filters=[test.a = Int64(1)]
2214 "
2215 )
2216 }
2217
2218 #[test]
2221 fn multi_filter() -> Result<()> {
2222 let table_scan = test_table_scan()?;
2224 let plan = LogicalPlanBuilder::from(table_scan)
2225 .project(vec![col("a").alias("b"), col("c")])?
2226 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2227 .filter(col("b").gt(lit(10i64)))?
2228 .filter(col("sum(test.c)").gt(lit(10i64)))?
2229 .build()?;
2230
2231 assert_snapshot!(plan,
2233 @r"
2234 Filter: sum(test.c) > Int64(10)
2235 Filter: b > Int64(10)
2236 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2237 Projection: test.a AS b, test.c
2238 TableScan: test
2239 ",
2240 );
2241 assert_optimized_plan_equal!(
2243 plan,
2244 @r"
2245 Filter: sum(test.c) > Int64(10)
2246 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2247 Projection: test.a AS b, test.c
2248 TableScan: test, full_filters=[test.a > Int64(10)]
2249 "
2250 )
2251 }
2252
2253 #[test]
2256 fn split_filter() -> Result<()> {
2257 let table_scan = test_table_scan()?;
2259 let plan = LogicalPlanBuilder::from(table_scan)
2260 .project(vec![col("a").alias("b"), col("c")])?
2261 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2262 .filter(and(
2263 col("sum(test.c)").gt(lit(10i64)),
2264 and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2265 ))?
2266 .build()?;
2267
2268 assert_snapshot!(plan,
2270 @r"
2271 Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2272 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2273 Projection: test.a AS b, test.c
2274 TableScan: test
2275 ",
2276 );
2277 assert_optimized_plan_equal!(
2279 plan,
2280 @r"
2281 Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2282 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2283 Projection: test.a AS b, test.c
2284 TableScan: test, full_filters=[test.a > Int64(10)]
2285 "
2286 )
2287 }
2288
2289 #[test]
2291 fn double_limit() -> Result<()> {
2292 let table_scan = test_table_scan()?;
2293 let plan = LogicalPlanBuilder::from(table_scan)
2294 .project(vec![col("a"), col("b")])?
2295 .limit(0, Some(20))?
2296 .limit(0, Some(10))?
2297 .project(vec![col("a"), col("b")])?
2298 .filter(col("a").eq(lit(1i64)))?
2299 .build()?;
2300 assert_optimized_plan_equal!(
2302 plan,
2303 @r"
2304 Projection: test.a, test.b
2305 Filter: test.a = Int64(1)
2306 Limit: skip=0, fetch=10
2307 Limit: skip=0, fetch=20
2308 Projection: test.a, test.b
2309 TableScan: test
2310 "
2311 )
2312 }
2313
2314 #[test]
2315 fn union_all() -> Result<()> {
2316 let table_scan = test_table_scan()?;
2317 let table_scan2 = test_table_scan_with_name("test2")?;
2318 let plan = LogicalPlanBuilder::from(table_scan)
2319 .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2320 .filter(col("a").eq(lit(1i64)))?
2321 .build()?;
2322 assert_optimized_plan_equal!(
2324 plan,
2325 @r"
2326 Union
2327 TableScan: test, full_filters=[test.a = Int64(1)]
2328 TableScan: test2, full_filters=[test2.a = Int64(1)]
2329 "
2330 )
2331 }
2332
2333 #[test]
2334 fn union_all_on_projection() -> Result<()> {
2335 let table_scan = test_table_scan()?;
2336 let table = LogicalPlanBuilder::from(table_scan)
2337 .project(vec![col("a").alias("b")])?
2338 .alias("test2")?;
2339
2340 let plan = table
2341 .clone()
2342 .union(table.build()?)?
2343 .filter(col("b").eq(lit(1i64)))?
2344 .build()?;
2345
2346 assert_optimized_plan_equal!(
2348 plan,
2349 @r"
2350 Union
2351 SubqueryAlias: test2
2352 Projection: test.a AS b
2353 TableScan: test, full_filters=[test.a = Int64(1)]
2354 SubqueryAlias: test2
2355 Projection: test.a AS b
2356 TableScan: test, full_filters=[test.a = Int64(1)]
2357 "
2358 )
2359 }
2360
2361 #[test]
2362 fn test_union_different_schema() -> Result<()> {
2363 let left = LogicalPlanBuilder::from(test_table_scan()?)
2364 .project(vec![col("a"), col("b"), col("c")])?
2365 .build()?;
2366
2367 let schema = Schema::new(vec![
2368 Field::new("d", DataType::UInt32, false),
2369 Field::new("e", DataType::UInt32, false),
2370 Field::new("f", DataType::UInt32, false),
2371 ]);
2372 let right = table_scan(Some("test1"), &schema, None)?
2373 .project(vec![col("d"), col("e"), col("f")])?
2374 .build()?;
2375 let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2376 let plan = LogicalPlanBuilder::from(left)
2377 .cross_join(right)?
2378 .project(vec![col("test.a"), col("test1.d")])?
2379 .filter(filter)?
2380 .build()?;
2381
2382 assert_optimized_plan_equal!(
2383 plan,
2384 @r"
2385 Projection: test.a, test1.d
2386 Cross Join:
2387 Projection: test.a, test.b, test.c
2388 TableScan: test, full_filters=[test.a = Int32(1)]
2389 Projection: test1.d, test1.e, test1.f
2390 TableScan: test1, full_filters=[test1.d > Int32(2)]
2391 "
2392 )
2393 }
2394
2395 #[test]
2396 fn test_project_same_name_different_qualifier() -> Result<()> {
2397 let table_scan = test_table_scan()?;
2398 let left = LogicalPlanBuilder::from(table_scan)
2399 .project(vec![col("a"), col("b"), col("c")])?
2400 .build()?;
2401 let right_table_scan = test_table_scan_with_name("test1")?;
2402 let right = LogicalPlanBuilder::from(right_table_scan)
2403 .project(vec![col("a"), col("b"), col("c")])?
2404 .build()?;
2405 let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2406 let plan = LogicalPlanBuilder::from(left)
2407 .cross_join(right)?
2408 .project(vec![col("test.a"), col("test1.a")])?
2409 .filter(filter)?
2410 .build()?;
2411
2412 assert_optimized_plan_equal!(
2413 plan,
2414 @r"
2415 Projection: test.a, test1.a
2416 Cross Join:
2417 Projection: test.a, test.b, test.c
2418 TableScan: test, full_filters=[test.a = Int32(1)]
2419 Projection: test1.a, test1.b, test1.c
2420 TableScan: test1, full_filters=[test1.a > Int32(2)]
2421 "
2422 )
2423 }
2424
2425 #[test]
2427 fn filter_2_breaks_limits() -> Result<()> {
2428 let table_scan = test_table_scan()?;
2429 let plan = LogicalPlanBuilder::from(table_scan)
2430 .project(vec![col("a")])?
2431 .filter(col("a").lt_eq(lit(1i64)))?
2432 .limit(0, Some(1))?
2433 .project(vec![col("a")])?
2434 .filter(col("a").gt_eq(lit(1i64)))?
2435 .build()?;
2436 assert_snapshot!(plan,
2440 @r"
2441 Filter: test.a >= Int64(1)
2442 Projection: test.a
2443 Limit: skip=0, fetch=1
2444 Filter: test.a <= Int64(1)
2445 Projection: test.a
2446 TableScan: test
2447 ",
2448 );
2449 assert_optimized_plan_equal!(
2450 plan,
2451 @r"
2452 Projection: test.a
2453 Filter: test.a >= Int64(1)
2454 Limit: skip=0, fetch=1
2455 Projection: test.a
2456 TableScan: test, full_filters=[test.a <= Int64(1)]
2457 "
2458 )
2459 }
2460
2461 #[test]
2463 fn two_filters_on_same_depth() -> Result<()> {
2464 let table_scan = test_table_scan()?;
2465 let plan = LogicalPlanBuilder::from(table_scan)
2466 .limit(0, Some(1))?
2467 .filter(col("a").lt_eq(lit(1i64)))?
2468 .filter(col("a").gt_eq(lit(1i64)))?
2469 .project(vec![col("a")])?
2470 .build()?;
2471
2472 assert_snapshot!(plan,
2474 @r"
2475 Projection: test.a
2476 Filter: test.a >= Int64(1)
2477 Filter: test.a <= Int64(1)
2478 Limit: skip=0, fetch=1
2479 TableScan: test
2480 ",
2481 );
2482 assert_optimized_plan_equal!(
2483 plan,
2484 @r"
2485 Projection: test.a
2486 Filter: test.a <= Int64(1) AND test.a >= Int64(1)
2487 Limit: skip=0, fetch=1
2488 TableScan: test
2489 "
2490 )
2491 }
2492
2493 #[test]
2496 fn filters_user_defined_node() -> Result<()> {
2497 let table_scan = test_table_scan()?;
2498 let plan = LogicalPlanBuilder::from(table_scan)
2499 .filter(col("a").lt_eq(lit(1i64)))?
2500 .build()?;
2501
2502 let plan = user_defined::new(plan);
2503
2504 assert_snapshot!(plan,
2506 @r"
2507 TestUserDefined
2508 Filter: test.a <= Int64(1)
2509 TableScan: test
2510 ",
2511 );
2512 assert_optimized_plan_equal!(
2513 plan,
2514 @r"
2515 TestUserDefined
2516 TableScan: test, full_filters=[test.a <= Int64(1)]
2517 "
2518 )
2519 }
2520
2521 #[test]
2523 fn filter_on_join_on_common_independent() -> Result<()> {
2524 let table_scan = test_table_scan()?;
2525 let left = LogicalPlanBuilder::from(table_scan).build()?;
2526 let right_table_scan = test_table_scan_with_name("test2")?;
2527 let right = LogicalPlanBuilder::from(right_table_scan)
2528 .project(vec![col("a")])?
2529 .build()?;
2530 let plan = LogicalPlanBuilder::from(left)
2531 .join(
2532 right,
2533 JoinType::Inner,
2534 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2535 None,
2536 )?
2537 .filter(col("test.a").lt_eq(lit(1i64)))?
2538 .build()?;
2539
2540 assert_snapshot!(plan,
2542 @r"
2543 Filter: test.a <= Int64(1)
2544 Inner Join: test.a = test2.a
2545 TableScan: test
2546 Projection: test2.a
2547 TableScan: test2
2548 ",
2549 );
2550 assert_optimized_plan_equal!(
2552 plan,
2553 @r"
2554 Inner Join: test.a = test2.a
2555 TableScan: test, full_filters=[test.a <= Int64(1)]
2556 Projection: test2.a
2557 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2558 "
2559 )
2560 }
2561
2562 #[test]
2564 fn filter_using_join_on_common_independent() -> Result<()> {
2565 let table_scan = test_table_scan()?;
2566 let left = LogicalPlanBuilder::from(table_scan).build()?;
2567 let right_table_scan = test_table_scan_with_name("test2")?;
2568 let right = LogicalPlanBuilder::from(right_table_scan)
2569 .project(vec![col("a")])?
2570 .build()?;
2571 let plan = LogicalPlanBuilder::from(left)
2572 .join_using(
2573 right,
2574 JoinType::Inner,
2575 vec![Column::from_name("a".to_string())],
2576 )?
2577 .filter(col("a").lt_eq(lit(1i64)))?
2578 .build()?;
2579
2580 assert_snapshot!(plan,
2582 @r"
2583 Filter: test.a <= Int64(1)
2584 Inner Join: Using test.a = test2.a
2585 TableScan: test
2586 Projection: test2.a
2587 TableScan: test2
2588 ",
2589 );
2590 assert_optimized_plan_equal!(
2592 plan,
2593 @r"
2594 Inner Join: Using test.a = test2.a
2595 TableScan: test, full_filters=[test.a <= Int64(1)]
2596 Projection: test2.a
2597 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2598 "
2599 )
2600 }
2601
2602 #[test]
2604 fn filter_join_on_common_dependent() -> Result<()> {
2605 let table_scan = test_table_scan()?;
2606 let left = LogicalPlanBuilder::from(table_scan)
2607 .project(vec![col("a"), col("c")])?
2608 .build()?;
2609 let right_table_scan = test_table_scan_with_name("test2")?;
2610 let right = LogicalPlanBuilder::from(right_table_scan)
2611 .project(vec![col("a"), col("b")])?
2612 .build()?;
2613 let plan = LogicalPlanBuilder::from(left)
2614 .join(
2615 right,
2616 JoinType::Inner,
2617 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2618 None,
2619 )?
2620 .filter(col("c").lt_eq(col("b")))?
2621 .build()?;
2622
2623 assert_snapshot!(plan,
2625 @r"
2626 Filter: test.c <= test2.b
2627 Inner Join: test.a = test2.a
2628 Projection: test.a, test.c
2629 TableScan: test
2630 Projection: test2.a, test2.b
2631 TableScan: test2
2632 ",
2633 );
2634 assert_optimized_plan_equal!(
2636 plan,
2637 @r"
2638 Inner Join: test.a = test2.a Filter: test.c <= test2.b
2639 Projection: test.a, test.c
2640 TableScan: test
2641 Projection: test2.a, test2.b
2642 TableScan: test2
2643 "
2644 )
2645 }
2646
2647 #[test]
2649 fn filter_join_on_one_side() -> Result<()> {
2650 let table_scan = test_table_scan()?;
2651 let left = LogicalPlanBuilder::from(table_scan)
2652 .project(vec![col("a"), col("b")])?
2653 .build()?;
2654 let table_scan_right = test_table_scan_with_name("test2")?;
2655 let right = LogicalPlanBuilder::from(table_scan_right)
2656 .project(vec![col("a"), col("c")])?
2657 .build()?;
2658
2659 let plan = LogicalPlanBuilder::from(left)
2660 .join(
2661 right,
2662 JoinType::Inner,
2663 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2664 None,
2665 )?
2666 .filter(col("b").lt_eq(lit(1i64)))?
2667 .build()?;
2668
2669 assert_snapshot!(plan,
2671 @r"
2672 Filter: test.b <= Int64(1)
2673 Inner Join: test.a = test2.a
2674 Projection: test.a, test.b
2675 TableScan: test
2676 Projection: test2.a, test2.c
2677 TableScan: test2
2678 ",
2679 );
2680 assert_optimized_plan_equal!(
2681 plan,
2682 @r"
2683 Inner Join: test.a = test2.a
2684 Projection: test.a, test.b
2685 TableScan: test, full_filters=[test.b <= Int64(1)]
2686 Projection: test2.a, test2.c
2687 TableScan: test2
2688 "
2689 )
2690 }
2691
2692 #[test]
2695 fn filter_using_left_join() -> Result<()> {
2696 let table_scan = test_table_scan()?;
2697 let left = LogicalPlanBuilder::from(table_scan).build()?;
2698 let right_table_scan = test_table_scan_with_name("test2")?;
2699 let right = LogicalPlanBuilder::from(right_table_scan)
2700 .project(vec![col("a")])?
2701 .build()?;
2702 let plan = LogicalPlanBuilder::from(left)
2703 .join_using(
2704 right,
2705 JoinType::Left,
2706 vec![Column::from_name("a".to_string())],
2707 )?
2708 .filter(col("test2.a").lt_eq(lit(1i64)))?
2709 .build()?;
2710
2711 assert_snapshot!(plan,
2713 @r"
2714 Filter: test2.a <= Int64(1)
2715 Left Join: Using test.a = test2.a
2716 TableScan: test
2717 Projection: test2.a
2718 TableScan: test2
2719 ",
2720 );
2721 assert_optimized_plan_equal!(
2723 plan,
2724 @r"
2725 Filter: test2.a <= Int64(1)
2726 Left Join: Using test.a = test2.a
2727 TableScan: test, full_filters=[test.a <= Int64(1)]
2728 Projection: test2.a
2729 TableScan: test2
2730 "
2731 )
2732 }
2733
2734 #[test]
2736 fn filter_using_right_join() -> Result<()> {
2737 let table_scan = test_table_scan()?;
2738 let left = LogicalPlanBuilder::from(table_scan).build()?;
2739 let right_table_scan = test_table_scan_with_name("test2")?;
2740 let right = LogicalPlanBuilder::from(right_table_scan)
2741 .project(vec![col("a")])?
2742 .build()?;
2743 let plan = LogicalPlanBuilder::from(left)
2744 .join_using(
2745 right,
2746 JoinType::Right,
2747 vec![Column::from_name("a".to_string())],
2748 )?
2749 .filter(col("test.a").lt_eq(lit(1i64)))?
2750 .build()?;
2751
2752 assert_snapshot!(plan,
2754 @r"
2755 Filter: test.a <= Int64(1)
2756 Right Join: Using test.a = test2.a
2757 TableScan: test
2758 Projection: test2.a
2759 TableScan: test2
2760 ",
2761 );
2762 assert_optimized_plan_equal!(
2764 plan,
2765 @r"
2766 Filter: test.a <= Int64(1)
2767 Right Join: Using test.a = test2.a
2768 TableScan: test
2769 Projection: test2.a
2770 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2771 "
2772 )
2773 }
2774
2775 #[test]
2777 fn filter_using_left_join_on_common() -> Result<()> {
2778 let table_scan = test_table_scan()?;
2779 let left = LogicalPlanBuilder::from(table_scan).build()?;
2780 let right_table_scan = test_table_scan_with_name("test2")?;
2781 let right = LogicalPlanBuilder::from(right_table_scan)
2782 .project(vec![col("a")])?
2783 .build()?;
2784 let plan = LogicalPlanBuilder::from(left)
2785 .join_using(
2786 right,
2787 JoinType::Left,
2788 vec![Column::from_name("a".to_string())],
2789 )?
2790 .filter(col("a").lt_eq(lit(1i64)))?
2791 .build()?;
2792
2793 assert_snapshot!(plan,
2795 @r"
2796 Filter: test.a <= Int64(1)
2797 Left Join: Using test.a = test2.a
2798 TableScan: test
2799 Projection: test2.a
2800 TableScan: test2
2801 ",
2802 );
2803 assert_optimized_plan_equal!(
2805 plan,
2806 @r"
2807 Left Join: Using test.a = test2.a
2808 TableScan: test, full_filters=[test.a <= Int64(1)]
2809 Projection: test2.a
2810 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2811 "
2812 )
2813 }
2814
2815 #[test]
2817 fn filter_using_right_join_on_common() -> Result<()> {
2818 let table_scan = test_table_scan()?;
2819 let left = LogicalPlanBuilder::from(table_scan).build()?;
2820 let right_table_scan = test_table_scan_with_name("test2")?;
2821 let right = LogicalPlanBuilder::from(right_table_scan)
2822 .project(vec![col("a")])?
2823 .build()?;
2824 let plan = LogicalPlanBuilder::from(left)
2825 .join_using(
2826 right,
2827 JoinType::Right,
2828 vec![Column::from_name("a".to_string())],
2829 )?
2830 .filter(col("test2.a").lt_eq(lit(1i64)))?
2831 .build()?;
2832
2833 assert_snapshot!(plan,
2835 @r"
2836 Filter: test2.a <= Int64(1)
2837 Right Join: Using test.a = test2.a
2838 TableScan: test
2839 Projection: test2.a
2840 TableScan: test2
2841 ",
2842 );
2843 assert_optimized_plan_equal!(
2845 plan,
2846 @r"
2847 Right Join: Using test.a = test2.a
2848 TableScan: test, full_filters=[test.a <= Int64(1)]
2849 Projection: test2.a
2850 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2851 "
2852 )
2853 }
2854
2855 #[test]
2857 fn join_on_with_filter() -> Result<()> {
2858 let table_scan = test_table_scan()?;
2859 let left = LogicalPlanBuilder::from(table_scan)
2860 .project(vec![col("a"), col("b"), col("c")])?
2861 .build()?;
2862 let right_table_scan = test_table_scan_with_name("test2")?;
2863 let right = LogicalPlanBuilder::from(right_table_scan)
2864 .project(vec![col("a"), col("b"), col("c")])?
2865 .build()?;
2866 let filter = col("test.c")
2867 .gt(lit(1u32))
2868 .and(col("test.b").lt(col("test2.b")))
2869 .and(col("test2.c").gt(lit(4u32)));
2870 let plan = LogicalPlanBuilder::from(left)
2871 .join(
2872 right,
2873 JoinType::Inner,
2874 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2875 Some(filter),
2876 )?
2877 .build()?;
2878
2879 assert_snapshot!(plan,
2881 @r"
2882 Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2883 Projection: test.a, test.b, test.c
2884 TableScan: test
2885 Projection: test2.a, test2.b, test2.c
2886 TableScan: test2
2887 ",
2888 );
2889 assert_optimized_plan_equal!(
2890 plan,
2891 @r"
2892 Inner Join: test.a = test2.a Filter: test.b < test2.b
2893 Projection: test.a, test.b, test.c
2894 TableScan: test, full_filters=[test.c > UInt32(1)]
2895 Projection: test2.a, test2.b, test2.c
2896 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2897 "
2898 )
2899 }
2900
2901 #[test]
2903 fn join_filter_removed() -> Result<()> {
2904 let table_scan = test_table_scan()?;
2905 let left = LogicalPlanBuilder::from(table_scan)
2906 .project(vec![col("a"), col("b"), col("c")])?
2907 .build()?;
2908 let right_table_scan = test_table_scan_with_name("test2")?;
2909 let right = LogicalPlanBuilder::from(right_table_scan)
2910 .project(vec![col("a"), col("b"), col("c")])?
2911 .build()?;
2912 let filter = col("test.b")
2913 .gt(lit(1u32))
2914 .and(col("test2.c").gt(lit(4u32)));
2915 let plan = LogicalPlanBuilder::from(left)
2916 .join(
2917 right,
2918 JoinType::Inner,
2919 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2920 Some(filter),
2921 )?
2922 .build()?;
2923
2924 assert_snapshot!(plan,
2926 @r"
2927 Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2928 Projection: test.a, test.b, test.c
2929 TableScan: test
2930 Projection: test2.a, test2.b, test2.c
2931 TableScan: test2
2932 ",
2933 );
2934 assert_optimized_plan_equal!(
2935 plan,
2936 @r"
2937 Inner Join: test.a = test2.a
2938 Projection: test.a, test.b, test.c
2939 TableScan: test, full_filters=[test.b > UInt32(1)]
2940 Projection: test2.a, test2.b, test2.c
2941 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2942 "
2943 )
2944 }
2945
2946 #[test]
2948 fn join_filter_on_common() -> Result<()> {
2949 let table_scan = test_table_scan()?;
2950 let left = LogicalPlanBuilder::from(table_scan)
2951 .project(vec![col("a")])?
2952 .build()?;
2953 let right_table_scan = test_table_scan_with_name("test2")?;
2954 let right = LogicalPlanBuilder::from(right_table_scan)
2955 .project(vec![col("b")])?
2956 .build()?;
2957 let filter = col("test.a").gt(lit(1u32));
2958 let plan = LogicalPlanBuilder::from(left)
2959 .join(
2960 right,
2961 JoinType::Inner,
2962 (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2963 Some(filter),
2964 )?
2965 .build()?;
2966
2967 assert_snapshot!(plan,
2969 @r"
2970 Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2971 Projection: test.a
2972 TableScan: test
2973 Projection: test2.b
2974 TableScan: test2
2975 ",
2976 );
2977 assert_optimized_plan_equal!(
2978 plan,
2979 @r"
2980 Inner Join: test.a = test2.b
2981 Projection: test.a
2982 TableScan: test, full_filters=[test.a > UInt32(1)]
2983 Projection: test2.b
2984 TableScan: test2, full_filters=[test2.b > UInt32(1)]
2985 "
2986 )
2987 }
2988
2989 #[test]
2991 fn left_join_on_with_filter() -> Result<()> {
2992 let table_scan = test_table_scan()?;
2993 let left = LogicalPlanBuilder::from(table_scan)
2994 .project(vec![col("a"), col("b"), col("c")])?
2995 .build()?;
2996 let right_table_scan = test_table_scan_with_name("test2")?;
2997 let right = LogicalPlanBuilder::from(right_table_scan)
2998 .project(vec![col("a"), col("b"), col("c")])?
2999 .build()?;
3000 let filter = col("test.a")
3001 .gt(lit(1u32))
3002 .and(col("test.b").lt(col("test2.b")))
3003 .and(col("test2.c").gt(lit(4u32)));
3004 let plan = LogicalPlanBuilder::from(left)
3005 .join(
3006 right,
3007 JoinType::Left,
3008 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3009 Some(filter),
3010 )?
3011 .build()?;
3012
3013 assert_snapshot!(plan,
3015 @r"
3016 Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3017 Projection: test.a, test.b, test.c
3018 TableScan: test
3019 Projection: test2.a, test2.b, test2.c
3020 TableScan: test2
3021 ",
3022 );
3023 assert_optimized_plan_equal!(
3024 plan,
3025 @r"
3026 Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
3027 Projection: test.a, test.b, test.c
3028 TableScan: test
3029 Projection: test2.a, test2.b, test2.c
3030 TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)]
3031 "
3032 )
3033 }
3034
3035 #[test]
3037 fn right_join_on_with_filter() -> Result<()> {
3038 let table_scan = test_table_scan()?;
3039 let left = LogicalPlanBuilder::from(table_scan)
3040 .project(vec![col("a"), col("b"), col("c")])?
3041 .build()?;
3042 let right_table_scan = test_table_scan_with_name("test2")?;
3043 let right = LogicalPlanBuilder::from(right_table_scan)
3044 .project(vec![col("a"), col("b"), col("c")])?
3045 .build()?;
3046 let filter = col("test.a")
3047 .gt(lit(1u32))
3048 .and(col("test.b").lt(col("test2.b")))
3049 .and(col("test2.c").gt(lit(4u32)));
3050 let plan = LogicalPlanBuilder::from(left)
3051 .join(
3052 right,
3053 JoinType::Right,
3054 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3055 Some(filter),
3056 )?
3057 .build()?;
3058
3059 assert_snapshot!(plan,
3061 @r"
3062 Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3063 Projection: test.a, test.b, test.c
3064 TableScan: test
3065 Projection: test2.a, test2.b, test2.c
3066 TableScan: test2
3067 ",
3068 );
3069 assert_optimized_plan_equal!(
3070 plan,
3071 @r"
3072 Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3073 Projection: test.a, test.b, test.c
3074 TableScan: test, full_filters=[test.a > UInt32(1)]
3075 Projection: test2.a, test2.b, test2.c
3076 TableScan: test2
3077 "
3078 )
3079 }
3080
3081 #[test]
3083 fn full_join_on_with_filter() -> Result<()> {
3084 let table_scan = test_table_scan()?;
3085 let left = LogicalPlanBuilder::from(table_scan)
3086 .project(vec![col("a"), col("b"), col("c")])?
3087 .build()?;
3088 let right_table_scan = test_table_scan_with_name("test2")?;
3089 let right = LogicalPlanBuilder::from(right_table_scan)
3090 .project(vec![col("a"), col("b"), col("c")])?
3091 .build()?;
3092 let filter = col("test.a")
3093 .gt(lit(1u32))
3094 .and(col("test.b").lt(col("test2.b")))
3095 .and(col("test2.c").gt(lit(4u32)));
3096 let plan = LogicalPlanBuilder::from(left)
3097 .join(
3098 right,
3099 JoinType::Full,
3100 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3101 Some(filter),
3102 )?
3103 .build()?;
3104
3105 assert_snapshot!(plan,
3107 @r"
3108 Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3109 Projection: test.a, test.b, test.c
3110 TableScan: test
3111 Projection: test2.a, test2.b, test2.c
3112 TableScan: test2
3113 ",
3114 );
3115 assert_optimized_plan_equal!(
3116 plan,
3117 @r"
3118 Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3119 Projection: test.a, test.b, test.c
3120 TableScan: test
3121 Projection: test2.a, test2.b, test2.c
3122 TableScan: test2
3123 "
3124 )
3125 }
3126
3127 struct PushDownProvider {
3128 pub filter_support: TableProviderFilterPushDown,
3129 }
3130
3131 #[async_trait]
3132 impl TableSource for PushDownProvider {
3133 fn schema(&self) -> SchemaRef {
3134 Arc::new(Schema::new(vec![
3135 Field::new("a", DataType::Int32, true),
3136 Field::new("b", DataType::Int32, true),
3137 ]))
3138 }
3139
3140 fn table_type(&self) -> TableType {
3141 TableType::Base
3142 }
3143
3144 fn supports_filters_pushdown(
3145 &self,
3146 filters: &[&Expr],
3147 ) -> Result<Vec<TableProviderFilterPushDown>> {
3148 Ok((0..filters.len())
3149 .map(|_| self.filter_support.clone())
3150 .collect())
3151 }
3152 }
3153
3154 fn table_scan_with_pushdown_provider_builder(
3155 filter_support: TableProviderFilterPushDown,
3156 filters: Vec<Expr>,
3157 projection: Option<Vec<usize>>,
3158 ) -> Result<LogicalPlanBuilder> {
3159 let test_provider = PushDownProvider { filter_support };
3160
3161 let table_scan = LogicalPlan::TableScan(TableScan {
3162 table_name: "test".into(),
3163 filters,
3164 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3165 projection,
3166 source: Arc::new(test_provider),
3167 fetch: None,
3168 });
3169
3170 Ok(LogicalPlanBuilder::from(table_scan))
3171 }
3172
3173 fn table_scan_with_pushdown_provider(
3174 filter_support: TableProviderFilterPushDown,
3175 ) -> Result<LogicalPlan> {
3176 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3177 .filter(col("a").eq(lit(1i64)))?
3178 .build()
3179 }
3180
3181 #[test]
3182 fn filter_with_table_provider_exact() -> Result<()> {
3183 let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3184
3185 assert_optimized_plan_equal!(
3186 plan,
3187 @"TableScan: test, full_filters=[a = Int64(1)]"
3188 )
3189 }
3190
3191 #[test]
3192 fn filter_with_table_provider_inexact() -> Result<()> {
3193 let plan =
3194 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3195
3196 assert_optimized_plan_equal!(
3197 plan,
3198 @r"
3199 Filter: a = Int64(1)
3200 TableScan: test, partial_filters=[a = Int64(1)]
3201 "
3202 )
3203 }
3204
3205 #[test]
3206 fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3207 let plan =
3208 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3209
3210 let optimized_plan = PushDownFilter::new()
3211 .rewrite(plan, &OptimizerContext::new())
3212 .expect("failed to optimize plan")
3213 .data;
3214
3215 assert_optimized_plan_equal!(
3218 optimized_plan,
3219 @r"
3220 Filter: a = Int64(1)
3221 TableScan: test, partial_filters=[a = Int64(1)]
3222 "
3223 )
3224 }
3225
3226 #[test]
3227 fn filter_with_table_provider_unsupported() -> Result<()> {
3228 let plan =
3229 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3230
3231 assert_optimized_plan_equal!(
3232 plan,
3233 @r"
3234 Filter: a = Int64(1)
3235 TableScan: test
3236 "
3237 )
3238 }
3239
3240 #[test]
3241 fn multi_combined_filter() -> Result<()> {
3242 let plan = table_scan_with_pushdown_provider_builder(
3243 TableProviderFilterPushDown::Inexact,
3244 vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3245 Some(vec![0]),
3246 )?
3247 .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3248 .project(vec![col("a"), col("b")])?
3249 .build()?;
3250
3251 assert_optimized_plan_equal!(
3252 plan,
3253 @r"
3254 Projection: a, b
3255 Filter: a = Int64(10) AND b > Int64(11)
3256 TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3257 "
3258 )
3259 }
3260
3261 #[test]
3262 fn multi_combined_two_filters() -> Result<()> {
3263 let plan = table_scan_with_pushdown_provider_builder(
3264 TableProviderFilterPushDown::Inexact,
3265 vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3266 Some(vec![0]),
3267 )?
3268 .filter(col("a").eq(lit(10i64)))?
3269 .filter(col("b").gt(lit(11i64)))?
3270 .project(vec![col("a"), col("b")])?
3271 .build()?;
3272
3273 assert_optimized_plan_equal!(
3274 plan,
3275 @r"
3276 Projection: a, b
3277 Filter: a = Int64(10) AND b > Int64(11)
3278 TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3279 "
3280 )
3281 }
3282
3283 #[test]
3284 fn multi_combined_filter_exact() -> Result<()> {
3285 let plan = table_scan_with_pushdown_provider_builder(
3286 TableProviderFilterPushDown::Exact,
3287 vec![],
3288 Some(vec![0]),
3289 )?
3290 .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3291 .project(vec![col("a"), col("b")])?
3292 .build()?;
3293
3294 assert_optimized_plan_equal!(
3295 plan,
3296 @r"
3297 Projection: a, b
3298 TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3299 "
3300 )
3301 }
3302
3303 #[test]
3304 fn multi_combined_two_filters_exact() -> Result<()> {
3305 let plan = table_scan_with_pushdown_provider_builder(
3306 TableProviderFilterPushDown::Exact,
3307 vec![],
3308 Some(vec![0]),
3309 )?
3310 .filter(col("a").eq(lit(10i64)))?
3311 .filter(col("b").gt(lit(11i64)))?
3312 .project(vec![col("a"), col("b")])?
3313 .build()?;
3314
3315 assert_optimized_plan_equal!(
3316 plan,
3317 @r"
3318 Projection: a, b
3319 TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3320 "
3321 )
3322 }
3323
3324 #[test]
3325 fn test_filter_with_alias() -> Result<()> {
3326 let table_scan = test_table_scan()?;
3330 let plan = LogicalPlanBuilder::from(table_scan)
3331 .project(vec![col("a").alias("b"), col("c")])?
3332 .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3333 .build()?;
3334
3335 assert_snapshot!(plan,
3337 @r"
3338 Filter: b > Int64(10) AND test.c > Int64(10)
3339 Projection: test.a AS b, test.c
3340 TableScan: test
3341 ",
3342 );
3343 assert_optimized_plan_equal!(
3345 plan,
3346 @r"
3347 Projection: test.a AS b, test.c
3348 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3349 "
3350 )
3351 }
3352
3353 #[test]
3354 fn test_filter_with_alias_2() -> Result<()> {
3355 let table_scan = test_table_scan()?;
3359 let plan = LogicalPlanBuilder::from(table_scan)
3360 .project(vec![col("a").alias("b"), col("c")])?
3361 .project(vec![col("b"), col("c")])?
3362 .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3363 .build()?;
3364
3365 assert_snapshot!(plan,
3367 @r"
3368 Filter: b > Int64(10) AND test.c > Int64(10)
3369 Projection: b, test.c
3370 Projection: test.a AS b, test.c
3371 TableScan: test
3372 ",
3373 );
3374 assert_optimized_plan_equal!(
3376 plan,
3377 @r"
3378 Projection: b, test.c
3379 Projection: test.a AS b, test.c
3380 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3381 "
3382 )
3383 }
3384
3385 #[test]
3386 fn test_filter_with_multi_alias() -> Result<()> {
3387 let table_scan = test_table_scan()?;
3388 let plan = LogicalPlanBuilder::from(table_scan)
3389 .project(vec![col("a").alias("b"), col("c").alias("d")])?
3390 .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3391 .build()?;
3392
3393 assert_snapshot!(plan,
3395 @r"
3396 Filter: b > Int64(10) AND d > Int64(10)
3397 Projection: test.a AS b, test.c AS d
3398 TableScan: test
3399 ",
3400 );
3401 assert_optimized_plan_equal!(
3403 plan,
3404 @r"
3405 Projection: test.a AS b, test.c AS d
3406 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3407 "
3408 )
3409 }
3410
3411 #[test]
3413 fn join_filter_with_alias() -> Result<()> {
3414 let table_scan = test_table_scan()?;
3415 let left = LogicalPlanBuilder::from(table_scan)
3416 .project(vec![col("a").alias("c")])?
3417 .build()?;
3418 let right_table_scan = test_table_scan_with_name("test2")?;
3419 let right = LogicalPlanBuilder::from(right_table_scan)
3420 .project(vec![col("b").alias("d")])?
3421 .build()?;
3422 let filter = col("c").gt(lit(1u32));
3423 let plan = LogicalPlanBuilder::from(left)
3424 .join(
3425 right,
3426 JoinType::Inner,
3427 (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3428 Some(filter),
3429 )?
3430 .build()?;
3431
3432 assert_snapshot!(plan,
3433 @r"
3434 Inner Join: c = d Filter: c > UInt32(1)
3435 Projection: test.a AS c
3436 TableScan: test
3437 Projection: test2.b AS d
3438 TableScan: test2
3439 ",
3440 );
3441 assert_optimized_plan_equal!(
3443 plan,
3444 @r"
3445 Inner Join: c = d
3446 Projection: test.a AS c
3447 TableScan: test, full_filters=[test.a > UInt32(1)]
3448 Projection: test2.b AS d
3449 TableScan: test2, full_filters=[test2.b > UInt32(1)]
3450 "
3451 )
3452 }
3453
3454 #[test]
3455 fn test_in_filter_with_alias() -> Result<()> {
3456 let table_scan = test_table_scan()?;
3460 let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3461 let plan = LogicalPlanBuilder::from(table_scan)
3462 .project(vec![col("a").alias("b"), col("c")])?
3463 .filter(in_list(col("b"), filter_value, false))?
3464 .build()?;
3465
3466 assert_snapshot!(plan,
3468 @r"
3469 Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3470 Projection: test.a AS b, test.c
3471 TableScan: test
3472 ",
3473 );
3474 assert_optimized_plan_equal!(
3476 plan,
3477 @r"
3478 Projection: test.a AS b, test.c
3479 TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3480 "
3481 )
3482 }
3483
3484 #[test]
3485 fn test_in_filter_with_alias_2() -> Result<()> {
3486 let table_scan = test_table_scan()?;
3490 let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3491 let plan = LogicalPlanBuilder::from(table_scan)
3492 .project(vec![col("a").alias("b"), col("c")])?
3493 .project(vec![col("b"), col("c")])?
3494 .filter(in_list(col("b"), filter_value, false))?
3495 .build()?;
3496
3497 assert_snapshot!(plan,
3499 @r"
3500 Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3501 Projection: b, test.c
3502 Projection: test.a AS b, test.c
3503 TableScan: test
3504 ",
3505 );
3506 assert_optimized_plan_equal!(
3508 plan,
3509 @r"
3510 Projection: b, test.c
3511 Projection: test.a AS b, test.c
3512 TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3513 "
3514 )
3515 }
3516
3517 #[test]
3518 fn test_in_subquery_with_alias() -> Result<()> {
3519 let table_scan = test_table_scan()?;
3522 let table_scan_sq = test_table_scan_with_name("sq")?;
3523 let subplan = Arc::new(
3524 LogicalPlanBuilder::from(table_scan_sq)
3525 .project(vec![col("c")])?
3526 .build()?,
3527 );
3528 let plan = LogicalPlanBuilder::from(table_scan)
3529 .project(vec![col("a").alias("b"), col("c")])?
3530 .filter(in_subquery(col("b"), subplan))?
3531 .build()?;
3532
3533 assert_snapshot!(plan,
3535 @r"
3536 Filter: b IN (<subquery>)
3537 Subquery:
3538 Projection: sq.c
3539 TableScan: sq
3540 Projection: test.a AS b, test.c
3541 TableScan: test
3542 ",
3543 );
3544 assert_optimized_plan_equal!(
3546 plan,
3547 @r"
3548 Projection: test.a AS b, test.c
3549 TableScan: test, full_filters=[test.a IN (<subquery>)]
3550 Subquery:
3551 Projection: sq.c
3552 TableScan: sq
3553 "
3554 )
3555 }
3556
3557 #[test]
3558 fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3559 let plan = LogicalPlanBuilder::empty(true)
3561 .project(vec![lit(0i64).alias("a")])?
3562 .alias("b")?
3563 .project(vec![col("b.a")])?
3564 .alias("b")?
3565 .filter(col("b.a").eq(lit(1i64)))?
3566 .project(vec![col("b.a")])?
3567 .build()?;
3568
3569 assert_snapshot!(plan,
3570 @r"
3571 Projection: b.a
3572 Filter: b.a = Int64(1)
3573 SubqueryAlias: b
3574 Projection: b.a
3575 SubqueryAlias: b
3576 Projection: Int64(0) AS a
3577 EmptyRelation: rows=1
3578 ",
3579 );
3580 assert_optimized_plan_equal!(
3583 plan,
3584 @r"
3585 Projection: b.a
3586 SubqueryAlias: b
3587 Projection: b.a
3588 SubqueryAlias: b
3589 Projection: Int64(0) AS a
3590 Filter: Int64(0) = Int64(1)
3591 EmptyRelation: rows=1
3592 "
3593 )
3594 }
3595
3596 #[test]
3597 fn test_crossjoin_with_or_clause() -> Result<()> {
3598 let table_scan = test_table_scan()?;
3600 let left = LogicalPlanBuilder::from(table_scan)
3601 .project(vec![col("a"), col("b"), col("c")])?
3602 .build()?;
3603 let right_table_scan = test_table_scan_with_name("test1")?;
3604 let right = LogicalPlanBuilder::from(right_table_scan)
3605 .project(vec![col("a").alias("d"), col("a").alias("e")])?
3606 .build()?;
3607 let filter = or(
3608 and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3609 and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3610 );
3611 let plan = LogicalPlanBuilder::from(left)
3612 .cross_join(right)?
3613 .filter(filter)?
3614 .build()?;
3615
3616 assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3617 Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3618 Projection: test.a, test.b, test.c
3619 TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3620 Projection: test1.a AS d, test1.a AS e
3621 TableScan: test1
3622 ")?;
3623
3624 let optimized_plan = PushDownFilter::new()
3627 .rewrite(plan, &OptimizerContext::new())
3628 .expect("failed to optimize plan")
3629 .data;
3630 assert_optimized_plan_equal!(
3631 optimized_plan,
3632 @r"
3633 Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3634 Projection: test.a, test.b, test.c
3635 TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3636 Projection: test1.a AS d, test1.a AS e
3637 TableScan: test1
3638 "
3639 )
3640 }
3641
3642 #[test]
3643 fn left_semi_join() -> Result<()> {
3644 let left = test_table_scan_with_name("test1")?;
3645 let right_table_scan = test_table_scan_with_name("test2")?;
3646 let right = LogicalPlanBuilder::from(right_table_scan)
3647 .project(vec![col("a"), col("b")])?
3648 .build()?;
3649 let plan = LogicalPlanBuilder::from(left)
3650 .join(
3651 right,
3652 JoinType::LeftSemi,
3653 (
3654 vec![Column::from_qualified_name("test1.a")],
3655 vec![Column::from_qualified_name("test2.a")],
3656 ),
3657 None,
3658 )?
3659 .filter(col("test2.a").lt_eq(lit(1i64)))?
3660 .build()?;
3661
3662 assert_snapshot!(plan,
3664 @r"
3665 Filter: test2.a <= Int64(1)
3666 LeftSemi Join: test1.a = test2.a
3667 TableScan: test1
3668 Projection: test2.a, test2.b
3669 TableScan: test2
3670 ",
3671 );
3672 assert_optimized_plan_equal!(
3674 plan,
3675 @r"
3676 Filter: test2.a <= Int64(1)
3677 LeftSemi Join: test1.a = test2.a
3678 TableScan: test1, full_filters=[test1.a <= Int64(1)]
3679 Projection: test2.a, test2.b
3680 TableScan: test2
3681 "
3682 )
3683 }
3684
3685 #[test]
3686 fn left_semi_join_with_filters() -> Result<()> {
3687 let left = test_table_scan_with_name("test1")?;
3688 let right_table_scan = test_table_scan_with_name("test2")?;
3689 let right = LogicalPlanBuilder::from(right_table_scan)
3690 .project(vec![col("a"), col("b")])?
3691 .build()?;
3692 let plan = LogicalPlanBuilder::from(left)
3693 .join(
3694 right,
3695 JoinType::LeftSemi,
3696 (
3697 vec![Column::from_qualified_name("test1.a")],
3698 vec![Column::from_qualified_name("test2.a")],
3699 ),
3700 Some(
3701 col("test1.b")
3702 .gt(lit(1u32))
3703 .and(col("test2.b").gt(lit(2u32))),
3704 ),
3705 )?
3706 .build()?;
3707
3708 assert_snapshot!(plan,
3710 @r"
3711 LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3712 TableScan: test1
3713 Projection: test2.a, test2.b
3714 TableScan: test2
3715 ",
3716 );
3717 assert_optimized_plan_equal!(
3719 plan,
3720 @r"
3721 LeftSemi Join: test1.a = test2.a
3722 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3723 Projection: test2.a, test2.b
3724 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3725 "
3726 )
3727 }
3728
3729 #[test]
3730 fn right_semi_join() -> Result<()> {
3731 let left = test_table_scan_with_name("test1")?;
3732 let right_table_scan = test_table_scan_with_name("test2")?;
3733 let right = LogicalPlanBuilder::from(right_table_scan)
3734 .project(vec![col("a"), col("b")])?
3735 .build()?;
3736 let plan = LogicalPlanBuilder::from(left)
3737 .join(
3738 right,
3739 JoinType::RightSemi,
3740 (
3741 vec![Column::from_qualified_name("test1.a")],
3742 vec![Column::from_qualified_name("test2.a")],
3743 ),
3744 None,
3745 )?
3746 .filter(col("test1.a").lt_eq(lit(1i64)))?
3747 .build()?;
3748
3749 assert_snapshot!(plan,
3751 @r"
3752 Filter: test1.a <= Int64(1)
3753 RightSemi Join: test1.a = test2.a
3754 TableScan: test1
3755 Projection: test2.a, test2.b
3756 TableScan: test2
3757 ",
3758 );
3759 assert_optimized_plan_equal!(
3761 plan,
3762 @r"
3763 Filter: test1.a <= Int64(1)
3764 RightSemi Join: test1.a = test2.a
3765 TableScan: test1
3766 Projection: test2.a, test2.b
3767 TableScan: test2, full_filters=[test2.a <= Int64(1)]
3768 "
3769 )
3770 }
3771
3772 #[test]
3773 fn right_semi_join_with_filters() -> Result<()> {
3774 let left = test_table_scan_with_name("test1")?;
3775 let right_table_scan = test_table_scan_with_name("test2")?;
3776 let right = LogicalPlanBuilder::from(right_table_scan)
3777 .project(vec![col("a"), col("b")])?
3778 .build()?;
3779 let plan = LogicalPlanBuilder::from(left)
3780 .join(
3781 right,
3782 JoinType::RightSemi,
3783 (
3784 vec![Column::from_qualified_name("test1.a")],
3785 vec![Column::from_qualified_name("test2.a")],
3786 ),
3787 Some(
3788 col("test1.b")
3789 .gt(lit(1u32))
3790 .and(col("test2.b").gt(lit(2u32))),
3791 ),
3792 )?
3793 .build()?;
3794
3795 assert_snapshot!(plan,
3797 @r"
3798 RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3799 TableScan: test1
3800 Projection: test2.a, test2.b
3801 TableScan: test2
3802 ",
3803 );
3804 assert_optimized_plan_equal!(
3806 plan,
3807 @r"
3808 RightSemi Join: test1.a = test2.a
3809 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3810 Projection: test2.a, test2.b
3811 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3812 "
3813 )
3814 }
3815
3816 #[test]
3817 fn left_anti_join() -> Result<()> {
3818 let table_scan = test_table_scan_with_name("test1")?;
3819 let left = LogicalPlanBuilder::from(table_scan)
3820 .project(vec![col("a"), col("b")])?
3821 .build()?;
3822 let right_table_scan = test_table_scan_with_name("test2")?;
3823 let right = LogicalPlanBuilder::from(right_table_scan)
3824 .project(vec![col("a"), col("b")])?
3825 .build()?;
3826 let plan = LogicalPlanBuilder::from(left)
3827 .join(
3828 right,
3829 JoinType::LeftAnti,
3830 (
3831 vec![Column::from_qualified_name("test1.a")],
3832 vec![Column::from_qualified_name("test2.a")],
3833 ),
3834 None,
3835 )?
3836 .filter(col("test2.a").gt(lit(2u32)))?
3837 .build()?;
3838
3839 assert_snapshot!(plan,
3841 @r"
3842 Filter: test2.a > UInt32(2)
3843 LeftAnti Join: test1.a = test2.a
3844 Projection: test1.a, test1.b
3845 TableScan: test1
3846 Projection: test2.a, test2.b
3847 TableScan: test2
3848 ",
3849 );
3850 assert_optimized_plan_equal!(
3852 plan,
3853 @r"
3854 Filter: test2.a > UInt32(2)
3855 LeftAnti Join: test1.a = test2.a
3856 Projection: test1.a, test1.b
3857 TableScan: test1, full_filters=[test1.a > UInt32(2)]
3858 Projection: test2.a, test2.b
3859 TableScan: test2
3860 "
3861 )
3862 }
3863
3864 #[test]
3865 fn left_anti_join_with_filters() -> Result<()> {
3866 let table_scan = test_table_scan_with_name("test1")?;
3867 let left = LogicalPlanBuilder::from(table_scan)
3868 .project(vec![col("a"), col("b")])?
3869 .build()?;
3870 let right_table_scan = test_table_scan_with_name("test2")?;
3871 let right = LogicalPlanBuilder::from(right_table_scan)
3872 .project(vec![col("a"), col("b")])?
3873 .build()?;
3874 let plan = LogicalPlanBuilder::from(left)
3875 .join(
3876 right,
3877 JoinType::LeftAnti,
3878 (
3879 vec![Column::from_qualified_name("test1.a")],
3880 vec![Column::from_qualified_name("test2.a")],
3881 ),
3882 Some(
3883 col("test1.b")
3884 .gt(lit(1u32))
3885 .and(col("test2.b").gt(lit(2u32))),
3886 ),
3887 )?
3888 .build()?;
3889
3890 assert_snapshot!(plan,
3892 @r"
3893 LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3894 Projection: test1.a, test1.b
3895 TableScan: test1
3896 Projection: test2.a, test2.b
3897 TableScan: test2
3898 ",
3899 );
3900 assert_optimized_plan_equal!(
3902 plan,
3903 @r"
3904 LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3905 Projection: test1.a, test1.b
3906 TableScan: test1
3907 Projection: test2.a, test2.b
3908 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3909 "
3910 )
3911 }
3912
3913 #[test]
3914 fn right_anti_join() -> Result<()> {
3915 let table_scan = test_table_scan_with_name("test1")?;
3916 let left = LogicalPlanBuilder::from(table_scan)
3917 .project(vec![col("a"), col("b")])?
3918 .build()?;
3919 let right_table_scan = test_table_scan_with_name("test2")?;
3920 let right = LogicalPlanBuilder::from(right_table_scan)
3921 .project(vec![col("a"), col("b")])?
3922 .build()?;
3923 let plan = LogicalPlanBuilder::from(left)
3924 .join(
3925 right,
3926 JoinType::RightAnti,
3927 (
3928 vec![Column::from_qualified_name("test1.a")],
3929 vec![Column::from_qualified_name("test2.a")],
3930 ),
3931 None,
3932 )?
3933 .filter(col("test1.a").gt(lit(2u32)))?
3934 .build()?;
3935
3936 assert_snapshot!(plan,
3938 @r"
3939 Filter: test1.a > UInt32(2)
3940 RightAnti Join: test1.a = test2.a
3941 Projection: test1.a, test1.b
3942 TableScan: test1
3943 Projection: test2.a, test2.b
3944 TableScan: test2
3945 ",
3946 );
3947 assert_optimized_plan_equal!(
3949 plan,
3950 @r"
3951 Filter: test1.a > UInt32(2)
3952 RightAnti Join: test1.a = test2.a
3953 Projection: test1.a, test1.b
3954 TableScan: test1
3955 Projection: test2.a, test2.b
3956 TableScan: test2, full_filters=[test2.a > UInt32(2)]
3957 "
3958 )
3959 }
3960
3961 #[test]
3962 fn right_anti_join_with_filters() -> Result<()> {
3963 let table_scan = test_table_scan_with_name("test1")?;
3964 let left = LogicalPlanBuilder::from(table_scan)
3965 .project(vec![col("a"), col("b")])?
3966 .build()?;
3967 let right_table_scan = test_table_scan_with_name("test2")?;
3968 let right = LogicalPlanBuilder::from(right_table_scan)
3969 .project(vec![col("a"), col("b")])?
3970 .build()?;
3971 let plan = LogicalPlanBuilder::from(left)
3972 .join(
3973 right,
3974 JoinType::RightAnti,
3975 (
3976 vec![Column::from_qualified_name("test1.a")],
3977 vec![Column::from_qualified_name("test2.a")],
3978 ),
3979 Some(
3980 col("test1.b")
3981 .gt(lit(1u32))
3982 .and(col("test2.b").gt(lit(2u32))),
3983 ),
3984 )?
3985 .build()?;
3986
3987 assert_snapshot!(plan,
3989 @r"
3990 RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3991 Projection: test1.a, test1.b
3992 TableScan: test1
3993 Projection: test2.a, test2.b
3994 TableScan: test2
3995 ",
3996 );
3997 assert_optimized_plan_equal!(
3999 plan,
4000 @r"
4001 RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
4002 Projection: test1.a, test1.b
4003 TableScan: test1, full_filters=[test1.b > UInt32(1)]
4004 Projection: test2.a, test2.b
4005 TableScan: test2
4006 "
4007 )
4008 }
4009
4010 #[derive(Debug, PartialEq, Eq, Hash)]
4011 struct TestScalarUDF {
4012 signature: Signature,
4013 }
4014
4015 impl ScalarUDFImpl for TestScalarUDF {
4016 fn name(&self) -> &str {
4017 "TestScalarUDF"
4018 }
4019
4020 fn signature(&self) -> &Signature {
4021 &self.signature
4022 }
4023
4024 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4025 Ok(DataType::Int32)
4026 }
4027
4028 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
4029 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
4030 }
4031 }
4032
4033 #[test]
4034 fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
4035 let table_scan = test_table_scan_with_name("test1")?;
4037 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4038 signature: Signature::exact(vec![], Volatility::Volatile),
4039 });
4040 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4041
4042 let plan = LogicalPlanBuilder::from(table_scan)
4043 .aggregate(vec![col("a")], vec![sum(col("b"))])?
4044 .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
4045 .alias("t")?
4046 .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
4047 .project(vec![col("t.a"), col("t.r")])?
4048 .build()?;
4049
4050 assert_snapshot!(plan,
4051 @r"
4052 Projection: t.a, t.r
4053 Filter: t.a > Int32(5) AND t.r > Float64(0.5)
4054 SubqueryAlias: t
4055 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
4056 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
4057 TableScan: test1
4058 ",
4059 );
4060 assert_optimized_plan_equal!(
4061 plan,
4062 @r"
4063 Projection: t.a, t.r
4064 SubqueryAlias: t
4065 Filter: r > Float64(0.5)
4066 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
4067 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
4068 TableScan: test1, full_filters=[test1.a > Int32(5)]
4069 "
4070 )
4071 }
4072
4073 #[test]
4074 fn test_push_down_volatile_function_in_join() -> Result<()> {
4075 let table_scan = test_table_scan_with_name("test1")?;
4077 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4078 signature: Signature::exact(vec![], Volatility::Volatile),
4079 });
4080 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4081 let left = LogicalPlanBuilder::from(table_scan).build()?;
4082 let right_table_scan = test_table_scan_with_name("test2")?;
4083 let right = LogicalPlanBuilder::from(right_table_scan).build()?;
4084 let plan = LogicalPlanBuilder::from(left)
4085 .join(
4086 right,
4087 JoinType::Inner,
4088 (
4089 vec![Column::from_qualified_name("test1.a")],
4090 vec![Column::from_qualified_name("test2.a")],
4091 ),
4092 None,
4093 )?
4094 .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4095 .alias("t")?
4096 .filter(col("t.r").gt(lit(0.8)))?
4097 .project(vec![col("t.a"), col("t.r")])?
4098 .build()?;
4099
4100 assert_snapshot!(plan,
4101 @r"
4102 Projection: t.a, t.r
4103 Filter: t.r > Float64(0.8)
4104 SubqueryAlias: t
4105 Projection: test1.a AS a, TestScalarUDF() AS r
4106 Inner Join: test1.a = test2.a
4107 TableScan: test1
4108 TableScan: test2
4109 ",
4110 );
4111 assert_optimized_plan_equal!(
4112 plan,
4113 @r"
4114 Projection: t.a, t.r
4115 SubqueryAlias: t
4116 Filter: r > Float64(0.8)
4117 Projection: test1.a AS a, TestScalarUDF() AS r
4118 Inner Join: test1.a = test2.a
4119 TableScan: test1
4120 TableScan: test2
4121 "
4122 )
4123 }
4124
4125 #[test]
4126 fn test_push_down_volatile_table_scan() -> Result<()> {
4127 let table_scan = test_table_scan()?;
4129 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4130 signature: Signature::exact(vec![], Volatility::Volatile),
4131 });
4132 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4133 let plan = LogicalPlanBuilder::from(table_scan)
4134 .project(vec![col("a"), col("b")])?
4135 .filter(expr.gt(lit(0.1)))?
4136 .build()?;
4137
4138 assert_snapshot!(plan,
4139 @r"
4140 Filter: TestScalarUDF() > Float64(0.1)
4141 Projection: test.a, test.b
4142 TableScan: test
4143 ",
4144 );
4145 assert_optimized_plan_equal!(
4146 plan,
4147 @r"
4148 Projection: test.a, test.b
4149 Filter: TestScalarUDF() > Float64(0.1)
4150 TableScan: test
4151 "
4152 )
4153 }
4154
4155 #[test]
4156 fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4157 let table_scan = test_table_scan()?;
4159 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4160 signature: Signature::exact(vec![], Volatility::Volatile),
4161 });
4162 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4163 let plan = LogicalPlanBuilder::from(table_scan)
4164 .project(vec![col("a"), col("b")])?
4165 .filter(
4166 expr.gt(lit(0.1))
4167 .and(col("t.a").gt(lit(5)))
4168 .and(col("t.b").gt(lit(10))),
4169 )?
4170 .build()?;
4171
4172 assert_snapshot!(plan,
4173 @r"
4174 Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4175 Projection: test.a, test.b
4176 TableScan: test
4177 ",
4178 );
4179 assert_optimized_plan_equal!(
4180 plan,
4181 @r"
4182 Projection: test.a, test.b
4183 Filter: TestScalarUDF() > Float64(0.1)
4184 TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4185 "
4186 )
4187 }
4188
4189 #[test]
4190 fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4191 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4193 signature: Signature::exact(vec![], Volatility::Volatile),
4194 });
4195 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4196 let plan = table_scan_with_pushdown_provider_builder(
4197 TableProviderFilterPushDown::Unsupported,
4198 vec![],
4199 None,
4200 )?
4201 .project(vec![col("a"), col("b")])?
4202 .filter(
4203 expr.gt(lit(0.1))
4204 .and(col("t.a").gt(lit(5)))
4205 .and(col("t.b").gt(lit(10))),
4206 )?
4207 .build()?;
4208
4209 assert_snapshot!(plan,
4210 @r"
4211 Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4212 Projection: a, b
4213 TableScan: test
4214 ",
4215 );
4216 assert_optimized_plan_equal!(
4217 plan,
4218 @r"
4219 Projection: a, b
4220 Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4221 TableScan: test
4222 "
4223 )
4224 }
4225
4226 #[test]
4227 fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4228 #[derive(Debug, Hash, Eq, PartialEq)]
4230 struct TestUserNode {
4231 schema: DFSchemaRef,
4232 }
4233
4234 impl PartialOrd for TestUserNode {
4235 fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4236 None
4237 }
4238 }
4239
4240 impl TestUserNode {
4241 fn new() -> Self {
4242 let schema = Arc::new(
4243 DFSchema::new_with_metadata(
4244 vec![(None, Field::new("a", DataType::Int64, false).into())],
4245 Default::default(),
4246 )
4247 .unwrap(),
4248 );
4249
4250 Self { schema }
4251 }
4252 }
4253
4254 impl UserDefinedLogicalNodeCore for TestUserNode {
4255 fn name(&self) -> &str {
4256 "test_node"
4257 }
4258
4259 fn inputs(&self) -> Vec<&LogicalPlan> {
4260 vec![]
4261 }
4262
4263 fn schema(&self) -> &DFSchemaRef {
4264 &self.schema
4265 }
4266
4267 fn expressions(&self) -> Vec<Expr> {
4268 vec![]
4269 }
4270
4271 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4272 write!(f, "TestUserNode")
4273 }
4274
4275 fn with_exprs_and_inputs(
4276 &self,
4277 exprs: Vec<Expr>,
4278 inputs: Vec<LogicalPlan>,
4279 ) -> Result<Self> {
4280 assert!(exprs.is_empty());
4281 assert!(inputs.is_empty());
4282 Ok(Self {
4283 schema: Arc::clone(&self.schema),
4284 })
4285 }
4286 }
4287
4288 let node = LogicalPlan::Extension(Extension {
4290 node: Arc::new(TestUserNode::new()),
4291 });
4292
4293 let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4294
4295 assert_snapshot!(plan,
4297 @r"
4298 Filter: Boolean(false)
4299 TestUserNode
4300 ",
4301 );
4302 assert_optimized_plan_equal!(
4304 plan,
4305 @r"
4306 Filter: Boolean(false)
4307 TestUserNode
4308 "
4309 )
4310 }
4311
4312 #[test]
4318 fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
4319 let table_scan = test_table_scan()?;
4320
4321 let proj = LogicalPlanBuilder::from(table_scan)
4323 .project(vec![
4324 leaf_udf_expr(col("a")).alias("val"),
4325 col("b"),
4326 col("c"),
4327 ])?
4328 .build()?;
4329
4330 let plan = LogicalPlanBuilder::from(proj)
4332 .filter(col("val").gt(lit(150i64)))?
4333 .build()?;
4334
4335 assert_optimized_plan_equal!(
4337 plan,
4338 @r"
4339 Filter: val > Int64(150)
4340 Projection: leaf_udf(test.a) AS val, test.b, test.c
4341 TableScan: test
4342 "
4343 )
4344 }
4345
4346 #[test]
4348 fn filter_mixed_predicates_partial_push() -> Result<()> {
4349 let table_scan = test_table_scan()?;
4350
4351 let proj = LogicalPlanBuilder::from(table_scan)
4353 .project(vec![
4354 leaf_udf_expr(col("a")).alias("val"),
4355 col("b"),
4356 col("c"),
4357 ])?
4358 .build()?;
4359
4360 let plan = LogicalPlanBuilder::from(proj)
4362 .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))?
4363 .build()?;
4364
4365 assert_optimized_plan_equal!(
4367 plan,
4368 @r"
4369 Filter: val > Int64(150)
4370 Projection: leaf_udf(test.a) AS val, test.b, test.c
4371 TableScan: test, full_filters=[test.b > Int64(5)]
4372 "
4373 )
4374 }
4375
4376 #[test]
4377 fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> {
4378 let scan = test_table_scan()?;
4379 let scan_with_fetch = match scan {
4380 LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan {
4381 fetch: Some(10),
4382 ..scan
4383 }),
4384 _ => unreachable!(),
4385 };
4386 let plan = LogicalPlanBuilder::from(scan_with_fetch)
4387 .filter(col("a").gt(lit(10i64)))?
4388 .build()?;
4389 assert_optimized_plan_equal!(
4391 plan,
4392 @r"
4393 Filter: test.a > Int64(10)
4394 TableScan: test, fetch=10
4395 "
4396 )
4397 }
4398
4399 #[test]
4400 fn filter_push_down_through_sort_without_fetch() -> Result<()> {
4401 let table_scan = test_table_scan()?;
4402 let plan = LogicalPlanBuilder::from(table_scan)
4403 .sort(vec![col("a").sort(true, true)])?
4404 .filter(col("a").gt(lit(10i64)))?
4405 .build()?;
4406 assert_optimized_plan_equal!(
4408 plan,
4409 @r"
4410 Sort: test.a ASC NULLS FIRST
4411 TableScan: test, full_filters=[test.a > Int64(10)]
4412 "
4413 )
4414 }
4415
4416 #[test]
4417 fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> {
4418 let table_scan = test_table_scan()?;
4419 let plan = LogicalPlanBuilder::from(table_scan)
4420 .sort_with_limit(vec![col("a").sort(true, true)], Some(5))?
4421 .filter(col("a").gt(lit(10i64)))?
4422 .build()?;
4423 assert_optimized_plan_equal!(
4426 plan,
4427 @r"
4428 Filter: test.a > Int64(10)
4429 Sort: test.a ASC NULLS FIRST, fetch=5
4430 TableScan: test
4431 "
4432 )
4433 }
4434}