1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31 Column, DFSchema, Result, assert_eq_or_internal_err, assert_or_internal_err,
32 internal_err, plan_err, qualified_name,
33};
34use datafusion_expr::expr::WindowFunction;
35use datafusion_expr::expr_rewriter::replace_col;
36use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
37use datafusion_expr::utils::{
38 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
39};
40use datafusion_expr::{
41 BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, and, or,
42};
43
44use crate::optimizer::ApplyOrder;
45use crate::simplify_expressions::simplify_predicates;
46use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
47use crate::{OptimizerConfig, OptimizerRule};
48
/// Optimizer rule that moves filter predicates closer to the plan's inputs.
///
/// The rule is applied top-down (see its [`OptimizerRule`] implementation).
/// When it meets a `Filter` node it tries to push the filter's conjuncts
/// into or below the filter's input (another filter, projection, join,
/// aggregate, window, union, table scan, ...). A bare `Join` node is also
/// handled so that predicates in its `ON` filter can be pushed to the join
/// inputs.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
138
139pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
164 match join_type {
165 JoinType::Inner => (true, true),
166 JoinType::Left => (true, false),
167 JoinType::Right => (false, true),
168 JoinType::Full => (false, false),
169 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
172 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
175 }
176}
177
178pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
188 match join_type {
189 JoinType::Inner => (true, true),
190 JoinType::Left => (false, true),
191 JoinType::Right => (true, false),
192 JoinType::Full => (false, false),
193 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
194 JoinType::LeftAnti => (false, true),
195 JoinType::RightAnti => (true, false),
196 JoinType::LeftMark => (false, true),
197 JoinType::RightMark => (true, false),
198 }
199}
200
/// Answers whether a predicate references only columns of the left or only
/// columns of the right input of a join.
///
/// The per-side column sets start out as `None` and are filled in by the
/// first `is_left_only` / `is_right_only` call that needs them.
#[derive(Debug)]
struct ColumnChecker<'a> {
    /// Schema of the left join input.
    left_schema: &'a DFSchema,
    /// Columns of `left_schema`; `None` until first needed.
    left_columns: Option<HashSet<Column>>,
    /// Schema of the right join input.
    right_schema: &'a DFSchema,
    /// Columns of `right_schema`; `None` until first needed.
    right_columns: Option<HashSet<Column>>,
}
214
215impl<'a> ColumnChecker<'a> {
216 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
217 Self {
218 left_schema,
219 left_columns: None,
220 right_schema,
221 right_columns: None,
222 }
223 }
224
225 fn is_left_only(&mut self, predicate: &Expr) -> bool {
227 if self.left_columns.is_none() {
228 self.left_columns = Some(schema_columns(self.left_schema));
229 }
230 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
231 }
232
233 fn is_right_only(&mut self, predicate: &Expr) -> bool {
235 if self.right_columns.is_none() {
236 self.right_columns = Some(schema_columns(self.right_schema));
237 }
238 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
239 }
240}
241
242fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
244 schema
245 .iter()
246 .flat_map(|(qualifier, field)| {
247 [
248 Column::new(qualifier.cloned(), field.name()),
249 Column::new_unqualified(field.name()),
251 ]
252 })
253 .collect::<HashSet<_>>()
254}
255
256fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
258 let mut is_evaluate = true;
259 predicate.apply(|expr| match expr {
260 Expr::Column(_)
261 | Expr::Literal(_, _)
262 | Expr::Placeholder(_)
263 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
264 Expr::Exists { .. }
265 | Expr::InSubquery(_)
266 | Expr::ScalarSubquery(_)
267 | Expr::OuterReferenceColumn(_, _)
268 | Expr::Unnest(_) => {
269 is_evaluate = false;
270 Ok(TreeNodeRecursion::Stop)
271 }
272 Expr::Alias(_)
273 | Expr::BinaryExpr(_)
274 | Expr::Like(_)
275 | Expr::SimilarTo(_)
276 | Expr::Not(_)
277 | Expr::IsNotNull(_)
278 | Expr::IsNull(_)
279 | Expr::IsTrue(_)
280 | Expr::IsFalse(_)
281 | Expr::IsUnknown(_)
282 | Expr::IsNotTrue(_)
283 | Expr::IsNotFalse(_)
284 | Expr::IsNotUnknown(_)
285 | Expr::Negative(_)
286 | Expr::Between(_)
287 | Expr::Case(_)
288 | Expr::Cast(_)
289 | Expr::TryCast(_)
290 | Expr::InList { .. }
291 | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
292 #[expect(deprecated)]
294 Expr::AggregateFunction(_)
295 | Expr::WindowFunction(_)
296 | Expr::Wildcard { .. }
297 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
298 })?;
299 Ok(is_evaluate)
300}
301
302fn extract_or_clauses_for_join<'a>(
336 filters: &'a [Expr],
337 schema: &'a DFSchema,
338) -> impl Iterator<Item = Expr> + 'a {
339 let schema_columns = schema_columns(schema);
340
341 filters.iter().filter_map(move |expr| {
343 if let Expr::BinaryExpr(BinaryExpr {
344 left,
345 op: Operator::Or,
346 right,
347 }) = expr
348 {
349 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
350 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
351
352 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
354 return Some(or(left_expr, right_expr));
355 }
356 }
357 None
358 })
359}
360
361fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
373 let mut predicate = None;
374
375 match expr {
376 Expr::BinaryExpr(BinaryExpr {
377 left: l_expr,
378 op: Operator::Or,
379 right: r_expr,
380 }) => {
381 let l_expr = extract_or_clause(l_expr, schema_columns);
382 let r_expr = extract_or_clause(r_expr, schema_columns);
383
384 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
385 predicate = Some(or(l_expr, r_expr));
386 }
387 }
388 Expr::BinaryExpr(BinaryExpr {
389 left: l_expr,
390 op: Operator::And,
391 right: r_expr,
392 }) => {
393 let l_expr = extract_or_clause(l_expr, schema_columns);
394 let r_expr = extract_or_clause(r_expr, schema_columns);
395
396 match (l_expr, r_expr) {
397 (Some(l_expr), Some(r_expr)) => {
398 predicate = Some(and(l_expr, r_expr));
399 }
400 (Some(l_expr), None) => {
401 predicate = Some(l_expr);
402 }
403 (None, Some(r_expr)) => {
404 predicate = Some(r_expr);
405 }
406 (None, None) => {
407 predicate = None;
408 }
409 }
410 }
411 _ => {
412 if has_all_column_refs(expr, schema_columns) {
413 predicate = Some(expr.clone());
414 }
415 }
416 }
417
418 predicate
419}
420
421fn push_down_all_join(
423 predicates: Vec<Expr>,
424 inferred_join_predicates: Vec<Expr>,
425 mut join: Join,
426 on_filter: Vec<Expr>,
427) -> Result<Transformed<LogicalPlan>> {
428 let is_inner_join = join.join_type == JoinType::Inner;
429 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
431
432 let left_schema = join.left.schema();
437 let right_schema = join.right.schema();
438 let mut left_push = vec![];
439 let mut right_push = vec![];
440 let mut keep_predicates = vec![];
441 let mut join_conditions = vec![];
442 let mut checker = ColumnChecker::new(left_schema, right_schema);
443 for predicate in predicates {
444 if left_preserved && checker.is_left_only(&predicate) {
445 left_push.push(predicate);
446 } else if right_preserved && checker.is_right_only(&predicate) {
447 right_push.push(predicate);
448 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
449 join_conditions.push(predicate);
452 } else {
453 keep_predicates.push(predicate);
454 }
455 }
456
457 for predicate in inferred_join_predicates {
459 if left_preserved && checker.is_left_only(&predicate) {
460 left_push.push(predicate);
461 } else if right_preserved && checker.is_right_only(&predicate) {
462 right_push.push(predicate);
463 }
464 }
465
466 let mut on_filter_join_conditions = vec![];
467 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
468
469 if !on_filter.is_empty() {
470 for on in on_filter {
471 if on_left_preserved && checker.is_left_only(&on) {
472 left_push.push(on)
473 } else if on_right_preserved && checker.is_right_only(&on) {
474 right_push.push(on)
475 } else {
476 on_filter_join_conditions.push(on)
477 }
478 }
479 }
480
481 if left_preserved {
484 left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
485 left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
486 }
487 if right_preserved {
488 right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
489 right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
490 }
491
492 if on_left_preserved {
495 left_push.extend(extract_or_clauses_for_join(
496 &on_filter_join_conditions,
497 left_schema,
498 ));
499 }
500 if on_right_preserved {
501 right_push.extend(extract_or_clauses_for_join(
502 &on_filter_join_conditions,
503 right_schema,
504 ));
505 }
506
507 if let Some(predicate) = conjunction(left_push) {
508 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
509 }
510 if let Some(predicate) = conjunction(right_push) {
511 join.right =
512 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
513 }
514
515 join_conditions.extend(on_filter_join_conditions);
517 join.filter = conjunction(join_conditions);
518
519 let plan = LogicalPlan::Join(join);
521 let plan = if let Some(predicate) = conjunction(keep_predicates) {
522 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
523 } else {
524 plan
525 };
526 Ok(Transformed::yes(plan))
527}
528
529fn push_down_join(
530 join: Join,
531 parent_predicate: Option<&Expr>,
532) -> Result<Transformed<LogicalPlan>> {
533 let predicates = parent_predicate
535 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536
537 let on_filters = join
539 .filter
540 .as_ref()
541 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
542
543 let inferred_join_predicates =
545 infer_join_predicates(&join, &predicates, &on_filters)?;
546
547 if on_filters.is_empty()
548 && predicates.is_empty()
549 && inferred_join_predicates.is_empty()
550 {
551 return Ok(Transformed::no(LogicalPlan::Join(join)));
552 }
553
554 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
555}
556
557fn infer_join_predicates(
567 join: &Join,
568 predicates: &[Expr],
569 on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571 let join_col_keys = join
573 .on
574 .iter()
575 .filter_map(|(l, r)| {
576 let left_col = l.try_as_col()?;
577 let right_col = r.try_as_col()?;
578 Some((left_col, right_col))
579 })
580 .collect::<Vec<_>>();
581
582 let join_type = join.join_type;
583
584 let mut inferred_predicates = InferredPredicates::new(join_type);
585
586 infer_join_predicates_from_predicates(
587 &join_col_keys,
588 predicates,
589 &mut inferred_predicates,
590 )?;
591
592 infer_join_predicates_from_on_filters(
593 &join_col_keys,
594 join_type,
595 on_filters,
596 &mut inferred_predicates,
597 )?;
598
599 Ok(inferred_predicates.predicates)
600}
601
/// Accumulator for predicates inferred from join keys.
///
/// For inner joins every candidate is accepted; for other join types a
/// candidate is accepted only if `is_restrict_null_predicate` reports
/// `Ok(true)` for it (see `try_build_predicate`).
struct InferredPredicates {
    /// Predicates inferred so far.
    predicates: Vec<Expr>,
    /// Whether the join being processed is an inner join.
    is_inner_join: bool,
}
614
615impl InferredPredicates {
616 fn new(join_type: JoinType) -> Self {
617 Self {
618 predicates: vec![],
619 is_inner_join: matches!(join_type, JoinType::Inner),
620 }
621 }
622
623 fn try_build_predicate(
624 &mut self,
625 predicate: Expr,
626 replace_map: &HashMap<&Column, &Column>,
627 ) -> Result<()> {
628 if self.is_inner_join
629 || matches!(
630 is_restrict_null_predicate(
631 predicate.clone(),
632 replace_map.keys().cloned()
633 ),
634 Ok(true)
635 )
636 {
637 self.predicates.push(replace_col(predicate, replace_map)?);
638 }
639
640 Ok(())
641 }
642}
643
/// Infers join predicates from the conjuncts of the filter above the join.
///
/// Inference is enabled in both directions (left key to right key and
/// right key to left key); results are appended to `inferred_predicates`.
fn infer_join_predicates_from_predicates(
    join_col_keys: &[(&Column, &Column)],
    predicates: &[Expr],
    inferred_predicates: &mut InferredPredicates,
) -> Result<()> {
    infer_join_predicates_impl::<true, true>(
        join_col_keys,
        predicates,
        inferred_predicates,
    )
}
663
664fn infer_join_predicates_from_on_filters(
676 join_col_keys: &[(&Column, &Column)],
677 join_type: JoinType,
678 on_filters: &[Expr],
679 inferred_predicates: &mut InferredPredicates,
680) -> Result<()> {
681 match join_type {
682 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
683 JoinType::Inner => infer_join_predicates_impl::<true, true>(
684 join_col_keys,
685 on_filters,
686 inferred_predicates,
687 ),
688 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
689 infer_join_predicates_impl::<true, false>(
690 join_col_keys,
691 on_filters,
692 inferred_predicates,
693 )
694 }
695 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
696 infer_join_predicates_impl::<false, true>(
697 join_col_keys,
698 on_filters,
699 inferred_predicates,
700 )
701 }
702 }
703}
704
705fn infer_join_predicates_impl<
721 const ENABLE_LEFT_TO_RIGHT: bool,
722 const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724 join_col_keys: &[(&Column, &Column)],
725 input_predicates: &[Expr],
726 inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728 for predicate in input_predicates {
729 let mut join_cols_to_replace = HashMap::new();
730
731 for &col in &predicate.column_refs() {
732 for (l, r) in join_col_keys.iter() {
733 if ENABLE_LEFT_TO_RIGHT && col == *l {
734 join_cols_to_replace.insert(col, *r);
735 break;
736 }
737 if ENABLE_RIGHT_TO_LEFT && col == *r {
738 join_cols_to_replace.insert(col, *l);
739 break;
740 }
741 }
742 }
743 if join_cols_to_replace.is_empty() {
744 continue;
745 }
746
747 inferred_predicates
748 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749 }
750 Ok(())
751}
752
impl OptimizerRule for PushDownFilter {
    /// Returns the name of this rule, `push_down_filter`.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// Requests top-down application of this rule.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements [`OptimizerRule::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Tries to push filter predicates one level further down the plan.
    ///
    /// * A bare `Join` is passed to `push_down_join` without a parent
    ///   predicate, so that its `ON` filter can be pushed to the inputs.
    /// * Any other node that is not a `Filter` is returned unchanged.
    /// * For a `Filter`, the predicate is first simplified; then, depending
    ///   on the kind of the filter's input, all or some of its conjuncts are
    ///   moved into or below that input. Conjuncts that cannot be moved stay
    ///   in a filter above the input.
    ///
    /// # Errors
    ///
    /// Propagates errors from predicate simplification, plan construction
    /// and the table source's `supports_filters_pushdown`.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // The options are fetched but not used by this rule.
        let _ = config.options();
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        // Captured before `plan` is destructured; used by the Union branch.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        // Simplify the conjuncts. The predicate is only rebuilt when the
        // number of conjuncts changed.
        let predicate = split_conjunction_owned(filter.predicate.clone());
        let old_predicate_len = predicate.len();
        let new_predicates = simplify_predicates(predicate)?;
        if old_predicate_len != new_predicates.len() {
            let Some(new_predicate) = conjunction(new_predicates) else {
                // No conjunct is left: drop the filter and return its input.
                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
            };
            filter.predicate = new_predicate;
        }

        // An input with a fetch (limit) or skip (offset) must see the same
        // rows as before, so the filter may not be moved below it.
        if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() {
            return Ok(Transformed::no(LogicalPlan::Filter(filter)));
        }

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over filter: merge both into one filter and retry.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                // The IndexSet removes duplicate conjuncts while keeping the
                // order of first occurrence.
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                self.rewrite(new_filter, config)
            }
            // Repartition, Distinct and Sort: the whole filter is placed
            // directly below the node.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            // Subquery alias: rewrite the alias' column names to the names
            // of the aliased input (matched by position), then push below.
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            // Projection: see `rewrite_projection`. Conjuncts it reports as
            // not pushable stay in a filter above the new projection.
            LogicalPlan::Projection(projection) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: restore the original filter.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            // Unnest: only conjuncts that do not touch unnested columns are
            // pushed below the unnest.
            LogicalPlan::Unnest(mut unnest) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                let mut unnest_struct_columns = vec![];

                // Output columns produced from struct columns are named
                // `<field>.<child>` with the qualifier of the struct field.
                for idx in &unnest.struct_type_columns {
                    let (sub_qualifier, field) =
                        unnest.input.schema().qualified_field(*idx);
                    let field_name = field.name().clone();

                    if let DataType::Struct(children) = field.data_type() {
                        for child in children {
                            let child_name = child.name().clone();
                            unnest_struct_columns.push(Column::new(
                                sub_qualifier.cloned(),
                                format!("{field_name}.{child_name}"),
                            ));
                        }
                    }
                }

                for predicate in predicates {
                    // All columns referenced anywhere in the predicate.
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    let contains_list_columns =
                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                            accum.contains(&unnest_list.output_column)
                        });
                    let contains_struct_columns =
                        unnest_struct_columns.iter().any(|c| accum.contains(c));

                    if contains_list_columns || contains_struct_columns {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Nothing can be pushed: restore the original filter.
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Unnest(input) becomes Unnest(Filter(input)).
                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot be `None`: `non_unnest_predicates` is not empty.
                    conjunction(non_unnest_predicates).unwrap(), unnest_input,
                )?);

                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                // Conjuncts on unnested columns stay above the unnest.
                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            // Union: the filter is copied onto every input, with the union's
            // column names rewritten to that input's names (by position).
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            // Aggregate: conjuncts that only reference group-by outputs are
            // pushed below the aggregate.
            LogicalPlan::Aggregate(agg) => {
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| {
                        let (relation, name) = e.qualified_name();
                        Column::new(relation, name)
                    })
                    .collect::<HashSet<_>>();

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Below the aggregate the group-by output columns do not
                // exist; replace them with the group-by expressions.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // Pushable conjuncts go below the aggregate.
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // The remaining conjuncts stay above the aggregate.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Window: conjuncts that only reference columns which are
            // partition keys of every window expression are pushed below.
            LogicalPlan::Window(window) => {
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| {
                            let (relation, name) = c.qualified_name();
                            Column::new(relation, name)
                        })
                        .collect::<HashSet<_>>()
                };
                // Entries of `window_expr` are expected to be window
                // functions, optionally wrapped in an alias.
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    unreachable!()
                                }
                            }
                            _ => {
                                unreachable!()
                            }
                        }
                    })
                    // Intersection of the key sets of all window expressions.
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // Pushable conjuncts go below the window.
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // The remaining conjuncts stay above the window.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            // Table scan: offer the non-volatile conjuncts to the source.
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile conjuncts are never offered to the source.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // The source must answer with one entry per offered filter.
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                assert_eq_or_internal_err!(
                    non_volatile_filters.len(),
                    supported_filters.len(),
                    "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                    supported_filters.len(),
                    non_volatile_filters.len()
                );

                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // Everything not reported as `Unsupported` goes to the scan.
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Existing scan filters come first; duplicates are removed.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Everything not reported as `Exact`, plus the volatile
                // conjuncts, stays in a filter above the scan.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            // Extension node: push conjuncts to all of its inputs unless
            // they reference a column the node asks to be protected.
            LogicalPlan::Extension(extension_plan) => {
                // Without inputs there is nowhere to push to.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` = push, `false` = keep.
                // Columns are compared by name only.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            Ok(false) } else {
                            Ok(true) }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Every conjunct is kept: restore the original filter.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Split the owned conjuncts according to the flags above.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                // Every input of the node gets the same pushed filter.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the node with its own expressions and new inputs.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input: the filter stays where it is.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1272
1273fn rewrite_projection(
1301 predicates: Vec<Expr>,
1302 mut projection: Projection,
1303) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1304 let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1308 .schema
1309 .iter()
1310 .zip(projection.expr.iter())
1311 .map(|((qualifier, field), expr)| {
1312 let expr = expr.clone().unalias();
1314
1315 (qualified_name(qualifier, field.name()), expr)
1316 })
1317 .partition(|(_, value)| value.is_volatile());
1318
1319 let mut push_predicates = vec![];
1320 let mut keep_predicates = vec![];
1321 for expr in predicates {
1322 if contain(&expr, &volatile_map) {
1323 keep_predicates.push(expr);
1324 } else {
1325 push_predicates.push(expr);
1326 }
1327 }
1328
1329 match conjunction(push_predicates) {
1330 Some(expr) => {
1331 let new_filter = LogicalPlan::Filter(Filter::try_new(
1334 replace_cols_by_name(expr, &non_volatile_map)?,
1335 std::mem::take(&mut projection.input),
1336 )?);
1337
1338 projection.input = Arc::new(new_filter);
1339
1340 Ok((
1341 Transformed::yes(LogicalPlan::Projection(projection)),
1342 conjunction(keep_predicates),
1343 ))
1344 }
1345 None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1346 }
1347}
1348
1349pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1351 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1352}
1353
1354fn insert_below(
1368 plan: LogicalPlan,
1369 new_child: LogicalPlan,
1370) -> Result<Transformed<LogicalPlan>> {
1371 let mut new_child = Some(new_child);
1372 let transformed_plan = plan.map_children(|_child| {
1373 if let Some(new_child) = new_child.take() {
1374 Ok(Transformed::yes(new_child))
1375 } else {
1376 internal_err!("node had more than one input")
1378 }
1379 })?;
1380
1381 assert_or_internal_err!(new_child.is_none(), "node had no inputs");
1383
1384 Ok(transformed_plan)
1385}
1386
1387impl PushDownFilter {
1388 #[expect(missing_docs)]
1389 pub fn new() -> Self {
1390 Self {}
1391 }
1392}
1393
1394pub fn replace_cols_by_name(
1396 e: Expr,
1397 replace_map: &HashMap<String, Expr>,
1398) -> Result<Expr> {
1399 e.transform_up(|expr| {
1400 Ok(if let Expr::Column(c) = &expr {
1401 match replace_map.get(&c.flat_name()) {
1402 Some(new_c) => Transformed::yes(new_c.clone()),
1403 None => Transformed::no(expr),
1404 }
1405 } else {
1406 Transformed::no(expr)
1407 })
1408 })
1409 .data()
1410}
1411
1412fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1414 let mut is_contain = false;
1415 e.apply(|expr| {
1416 Ok(if let Expr::Column(c) = &expr {
1417 match check_map.get(&c.flat_name()) {
1418 Some(_) => {
1419 is_contain = true;
1420 TreeNodeRecursion::Stop
1421 }
1422 None => TreeNodeRecursion::Continue,
1423 }
1424 } else {
1425 TreeNodeRecursion::Continue
1426 })
1427 })
1428 .unwrap();
1429 is_contain
1430}
1431
1432#[cfg(test)]
1433mod tests {
1434 use std::any::Any;
1435 use std::cmp::Ordering;
1436 use std::fmt::{Debug, Formatter};
1437
1438 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1439 use async_trait::async_trait;
1440
1441 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1442 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1443 use datafusion_expr::logical_plan::table_scan;
1444 use datafusion_expr::{
1445 ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder,
1446 ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
1447 UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, in_list,
1448 in_subquery, lit,
1449 };
1450
1451 use crate::OptimizerContext;
1452 use crate::assert_optimized_plan_eq_snapshot;
1453 use crate::optimizer::Optimizer;
1454 use crate::simplify_expressions::SimplifyExpressions;
1455 use crate::test::*;
1456 use datafusion_expr::test::function_stub::sum;
1457 use insta::assert_snapshot;
1458
1459 use super::*;
1460
    // Observer callback that ignores every notification; passed to
    // `Optimizer::optimize` by the test macros.
    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1462
    // Hands `$plan`, a context limited to one optimizer pass and a rule list
    // containing only `PushDownFilter` to `assert_optimized_plan_eq_snapshot!`,
    // together with the expected inline snapshot literal.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }
1478
    // Optimizes `$plan` with `SimplifyExpressions` followed by
    // `PushDownFilter` and compares the result with the expected inline
    // snapshot literal. Uses `?`, so it must be expanded inside a function
    // returning `Result`.
    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer = Optimizer::with_rules(vec![
                Arc::new(SimplifyExpressions::new()),
                Arc::new(PushDownFilter::new()),
            ]);
            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
            assert_snapshot!(optimized_plan, @ $expected);
            Ok::<(), DataFusionError>(())
        }};
    }
1493
1494 #[test]
1495 fn filter_before_projection() -> Result<()> {
1496 let table_scan = test_table_scan()?;
1497 let plan = LogicalPlanBuilder::from(table_scan)
1498 .project(vec![col("a"), col("b")])?
1499 .filter(col("a").eq(lit(1i64)))?
1500 .build()?;
1501 assert_optimized_plan_equal!(
1503 plan,
1504 @r"
1505 Projection: test.a, test.b
1506 TableScan: test, full_filters=[test.a = Int64(1)]
1507 "
1508 )
1509 }
1510
1511 #[test]
1512 fn filter_after_limit() -> Result<()> {
1513 let table_scan = test_table_scan()?;
1514 let plan = LogicalPlanBuilder::from(table_scan)
1515 .project(vec![col("a"), col("b")])?
1516 .limit(0, Some(10))?
1517 .filter(col("a").eq(lit(1i64)))?
1518 .build()?;
1519 assert_optimized_plan_equal!(
1521 plan,
1522 @r"
1523 Filter: test.a = Int64(1)
1524 Limit: skip=0, fetch=10
1525 Projection: test.a, test.b
1526 TableScan: test
1527 "
1528 )
1529 }
1530
1531 #[test]
1532 fn filter_no_columns() -> Result<()> {
1533 let table_scan = test_table_scan()?;
1534 let plan = LogicalPlanBuilder::from(table_scan)
1535 .filter(lit(0i64).eq(lit(1i64)))?
1536 .build()?;
1537 assert_optimized_plan_equal!(
1538 plan,
1539 @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1540 )
1541 }
1542
1543 #[test]
1544 fn filter_jump_2_plans() -> Result<()> {
1545 let table_scan = test_table_scan()?;
1546 let plan = LogicalPlanBuilder::from(table_scan)
1547 .project(vec![col("a"), col("b"), col("c")])?
1548 .project(vec![col("c"), col("b")])?
1549 .filter(col("a").eq(lit(1i64)))?
1550 .build()?;
1551 assert_optimized_plan_equal!(
1553 plan,
1554 @r"
1555 Projection: test.c, test.b
1556 Projection: test.a, test.b, test.c
1557 TableScan: test, full_filters=[test.a = Int64(1)]
1558 "
1559 )
1560 }
1561
1562 #[test]
1563 fn filter_move_agg() -> Result<()> {
1564 let table_scan = test_table_scan()?;
1565 let plan = LogicalPlanBuilder::from(table_scan)
1566 .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1567 .filter(col("a").gt(lit(10i64)))?
1568 .build()?;
1569 assert_optimized_plan_equal!(
1571 plan,
1572 @r"
1573 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1574 TableScan: test, full_filters=[test.a > Int64(10)]
1575 "
1576 )
1577 }
1578
1579 #[test]
1581 fn filter_move_agg_special() -> Result<()> {
1582 let schema = Schema::new(vec![
1583 Field::new("$a", DataType::UInt32, false),
1584 Field::new("$b", DataType::UInt32, false),
1585 Field::new("$c", DataType::UInt32, false),
1586 ]);
1587 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1588
1589 let plan = LogicalPlanBuilder::from(table_scan)
1590 .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1591 .filter(col("$a").gt(lit(10i64)))?
1592 .build()?;
1593 assert_optimized_plan_equal!(
1595 plan,
1596 @r"
1597 Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1598 TableScan: test, full_filters=[test.$a > Int64(10)]
1599 "
1600 )
1601 }
1602
1603 #[test]
1604 fn filter_complex_group_by() -> Result<()> {
1605 let table_scan = test_table_scan()?;
1606 let plan = LogicalPlanBuilder::from(table_scan)
1607 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1608 .filter(col("b").gt(lit(10i64)))?
1609 .build()?;
1610 assert_optimized_plan_equal!(
1611 plan,
1612 @r"
1613 Filter: test.b > Int64(10)
1614 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1615 TableScan: test
1616 "
1617 )
1618 }
1619
1620 #[test]
1621 fn push_agg_need_replace_expr() -> Result<()> {
1622 let plan = LogicalPlanBuilder::from(test_table_scan()?)
1623 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1624 .filter(col("test.b + test.a").gt(lit(10i64)))?
1625 .build()?;
1626 assert_optimized_plan_equal!(
1627 plan,
1628 @r"
1629 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1630 TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1631 "
1632 )
1633 }
1634
1635 #[test]
1636 fn filter_keep_agg() -> Result<()> {
1637 let table_scan = test_table_scan()?;
1638 let plan = LogicalPlanBuilder::from(table_scan)
1639 .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1640 .filter(col("b").gt(lit(10i64)))?
1641 .build()?;
1642 assert_optimized_plan_equal!(
1644 plan,
1645 @r"
1646 Filter: b > Int64(10)
1647 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1648 TableScan: test
1649 "
1650 )
1651 }
1652
1653 #[test]
1655 fn filter_move_window() -> Result<()> {
1656 let table_scan = test_table_scan()?;
1657
1658 let window = Expr::from(WindowFunction::new(
1659 WindowFunctionDefinition::WindowUDF(
1660 datafusion_functions_window::rank::rank_udwf(),
1661 ),
1662 vec![],
1663 ))
1664 .partition_by(vec![col("a"), col("b")])
1665 .order_by(vec![col("c").sort(true, true)])
1666 .build()
1667 .unwrap();
1668
1669 let plan = LogicalPlanBuilder::from(table_scan)
1670 .window(vec![window])?
1671 .filter(col("b").gt(lit(10i64)))?
1672 .build()?;
1673
1674 assert_optimized_plan_equal!(
1675 plan,
1676 @r"
1677 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1678 TableScan: test, full_filters=[test.b > Int64(10)]
1679 "
1680 )
1681 }
1682
1683 #[test]
1685 fn filter_window_special_identifier() -> Result<()> {
1686 let schema = Schema::new(vec![
1687 Field::new("$a", DataType::UInt32, false),
1688 Field::new("$b", DataType::UInt32, false),
1689 Field::new("$c", DataType::UInt32, false),
1690 ]);
1691 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1692
1693 let window = Expr::from(WindowFunction::new(
1694 WindowFunctionDefinition::WindowUDF(
1695 datafusion_functions_window::rank::rank_udwf(),
1696 ),
1697 vec![],
1698 ))
1699 .partition_by(vec![col("$a"), col("$b")])
1700 .order_by(vec![col("$c").sort(true, true)])
1701 .build()
1702 .unwrap();
1703
1704 let plan = LogicalPlanBuilder::from(table_scan)
1705 .window(vec![window])?
1706 .filter(col("$b").gt(lit(10i64)))?
1707 .build()?;
1708
1709 assert_optimized_plan_equal!(
1710 plan,
1711 @r"
1712 WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1713 TableScan: test, full_filters=[test.$b > Int64(10)]
1714 "
1715 )
1716 }
1717
1718 #[test]
1721 fn filter_move_complex_window() -> Result<()> {
1722 let table_scan = test_table_scan()?;
1723
1724 let window = Expr::from(WindowFunction::new(
1725 WindowFunctionDefinition::WindowUDF(
1726 datafusion_functions_window::rank::rank_udwf(),
1727 ),
1728 vec![],
1729 ))
1730 .partition_by(vec![col("a"), col("b")])
1731 .order_by(vec![col("c").sort(true, true)])
1732 .build()
1733 .unwrap();
1734
1735 let plan = LogicalPlanBuilder::from(table_scan)
1736 .window(vec![window])?
1737 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1738 .build()?;
1739
1740 assert_optimized_plan_equal!(
1741 plan,
1742 @r"
1743 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1744 TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1745 "
1746 )
1747 }
1748
1749 #[test]
1751 fn filter_move_partial_window() -> Result<()> {
1752 let table_scan = test_table_scan()?;
1753
1754 let window = Expr::from(WindowFunction::new(
1755 WindowFunctionDefinition::WindowUDF(
1756 datafusion_functions_window::rank::rank_udwf(),
1757 ),
1758 vec![],
1759 ))
1760 .partition_by(vec![col("a")])
1761 .order_by(vec![col("c").sort(true, true)])
1762 .build()
1763 .unwrap();
1764
1765 let plan = LogicalPlanBuilder::from(table_scan)
1766 .window(vec![window])?
1767 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1768 .build()?;
1769
1770 assert_optimized_plan_equal!(
1771 plan,
1772 @r"
1773 Filter: test.b = Int64(1)
1774 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1775 TableScan: test, full_filters=[test.a > Int64(10)]
1776 "
1777 )
1778 }
1779
1780 #[test]
1783 fn filter_expression_keep_window() -> Result<()> {
1784 let table_scan = test_table_scan()?;
1785
1786 let window = Expr::from(WindowFunction::new(
1787 WindowFunctionDefinition::WindowUDF(
1788 datafusion_functions_window::rank::rank_udwf(),
1789 ),
1790 vec![],
1791 ))
1792 .partition_by(vec![add(col("a"), col("b"))]) .order_by(vec![col("c").sort(true, true)])
1794 .build()
1795 .unwrap();
1796
1797 let plan = LogicalPlanBuilder::from(table_scan)
1798 .window(vec![window])?
1799 .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1802 .build()?;
1803
1804 assert_optimized_plan_equal!(
1805 plan,
1806 @r"
1807 Filter: test.a + test.b > Int64(10)
1808 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1809 TableScan: test
1810 "
1811 )
1812 }
1813
1814 #[test]
1816 fn filter_order_keep_window() -> Result<()> {
1817 let table_scan = test_table_scan()?;
1818
1819 let window = Expr::from(WindowFunction::new(
1820 WindowFunctionDefinition::WindowUDF(
1821 datafusion_functions_window::rank::rank_udwf(),
1822 ),
1823 vec![],
1824 ))
1825 .partition_by(vec![col("a")])
1826 .order_by(vec![col("c").sort(true, true)])
1827 .build()
1828 .unwrap();
1829
1830 let plan = LogicalPlanBuilder::from(table_scan)
1831 .window(vec![window])?
1832 .filter(col("c").gt(lit(10i64)))?
1833 .build()?;
1834
1835 assert_optimized_plan_equal!(
1836 plan,
1837 @r"
1838 Filter: test.c > Int64(10)
1839 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1840 TableScan: test
1841 "
1842 )
1843 }
1844
1845 #[test]
1848 fn filter_multiple_windows_common_partitions() -> Result<()> {
1849 let table_scan = test_table_scan()?;
1850
1851 let window1 = Expr::from(WindowFunction::new(
1852 WindowFunctionDefinition::WindowUDF(
1853 datafusion_functions_window::rank::rank_udwf(),
1854 ),
1855 vec![],
1856 ))
1857 .partition_by(vec![col("a")])
1858 .order_by(vec![col("c").sort(true, true)])
1859 .build()
1860 .unwrap();
1861
1862 let window2 = Expr::from(WindowFunction::new(
1863 WindowFunctionDefinition::WindowUDF(
1864 datafusion_functions_window::rank::rank_udwf(),
1865 ),
1866 vec![],
1867 ))
1868 .partition_by(vec![col("b"), col("a")])
1869 .order_by(vec![col("c").sort(true, true)])
1870 .build()
1871 .unwrap();
1872
1873 let plan = LogicalPlanBuilder::from(table_scan)
1874 .window(vec![window1, window2])?
1875 .filter(col("a").gt(lit(10i64)))? .build()?;
1877
1878 assert_optimized_plan_equal!(
1879 plan,
1880 @r"
1881 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1882 TableScan: test, full_filters=[test.a > Int64(10)]
1883 "
1884 )
1885 }
1886
1887 #[test]
1890 fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1891 let table_scan = test_table_scan()?;
1892
1893 let window1 = Expr::from(WindowFunction::new(
1894 WindowFunctionDefinition::WindowUDF(
1895 datafusion_functions_window::rank::rank_udwf(),
1896 ),
1897 vec![],
1898 ))
1899 .partition_by(vec![col("a")])
1900 .order_by(vec![col("c").sort(true, true)])
1901 .build()
1902 .unwrap();
1903
1904 let window2 = Expr::from(WindowFunction::new(
1905 WindowFunctionDefinition::WindowUDF(
1906 datafusion_functions_window::rank::rank_udwf(),
1907 ),
1908 vec![],
1909 ))
1910 .partition_by(vec![col("b"), col("a")])
1911 .order_by(vec![col("c").sort(true, true)])
1912 .build()
1913 .unwrap();
1914
1915 let plan = LogicalPlanBuilder::from(table_scan)
1916 .window(vec![window1, window2])?
1917 .filter(col("b").gt(lit(10i64)))? .build()?;
1919
1920 assert_optimized_plan_equal!(
1921 plan,
1922 @r"
1923 Filter: test.b > Int64(10)
1924 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1925 TableScan: test
1926 "
1927 )
1928 }
1929
1930 #[test]
1932 fn alias() -> Result<()> {
1933 let table_scan = test_table_scan()?;
1934 let plan = LogicalPlanBuilder::from(table_scan)
1935 .project(vec![col("a").alias("b"), col("c")])?
1936 .filter(col("b").eq(lit(1i64)))?
1937 .build()?;
1938 assert_optimized_plan_equal!(
1940 plan,
1941 @r"
1942 Projection: test.a AS b, test.c
1943 TableScan: test, full_filters=[test.a = Int64(1)]
1944 "
1945 )
1946 }
1947
1948 fn add(left: Expr, right: Expr) -> Expr {
1949 Expr::BinaryExpr(BinaryExpr::new(
1950 Box::new(left),
1951 Operator::Plus,
1952 Box::new(right),
1953 ))
1954 }
1955
1956 fn multiply(left: Expr, right: Expr) -> Expr {
1957 Expr::BinaryExpr(BinaryExpr::new(
1958 Box::new(left),
1959 Operator::Multiply,
1960 Box::new(right),
1961 ))
1962 }
1963
1964 #[test]
1966 fn complex_expression() -> Result<()> {
1967 let table_scan = test_table_scan()?;
1968 let plan = LogicalPlanBuilder::from(table_scan)
1969 .project(vec![
1970 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1971 col("c"),
1972 ])?
1973 .filter(col("b").eq(lit(1i64)))?
1974 .build()?;
1975
1976 assert_snapshot!(plan,
1978 @r"
1979 Filter: b = Int64(1)
1980 Projection: test.a * Int32(2) + test.c AS b, test.c
1981 TableScan: test
1982 ",
1983 );
1984 assert_optimized_plan_equal!(
1986 plan,
1987 @r"
1988 Projection: test.a * Int32(2) + test.c AS b, test.c
1989 TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1990 "
1991 )
1992 }
1993
1994 #[test]
1996 fn complex_plan() -> Result<()> {
1997 let table_scan = test_table_scan()?;
1998 let plan = LogicalPlanBuilder::from(table_scan)
1999 .project(vec![
2000 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
2001 col("c"),
2002 ])?
2003 .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2005 .filter(col("a").eq(lit(1i64)))?
2006 .build()?;
2007
2008 assert_snapshot!(plan,
2010 @r"
2011 Filter: a = Int64(1)
2012 Projection: b * Int32(3) AS a, test.c
2013 Projection: test.a * Int32(2) + test.c AS b, test.c
2014 TableScan: test
2015 ",
2016 );
2017 assert_optimized_plan_equal!(
2019 plan,
2020 @r"
2021 Projection: b * Int32(3) AS a, test.c
2022 Projection: test.a * Int32(2) + test.c AS b, test.c
2023 TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2024 "
2025 )
2026 }
2027
2028 #[derive(Debug, PartialEq, Eq, Hash)]
2029 struct NoopPlan {
2030 input: Vec<LogicalPlan>,
2031 schema: DFSchemaRef,
2032 }
2033
2034 impl PartialOrd for NoopPlan {
2036 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2037 self.input
2038 .partial_cmp(&other.input)
2039 .filter(|cmp| *cmp != Ordering::Equal || self == other)
2041 }
2042 }
2043
2044 impl UserDefinedLogicalNodeCore for NoopPlan {
2045 fn name(&self) -> &str {
2046 "NoopPlan"
2047 }
2048
2049 fn inputs(&self) -> Vec<&LogicalPlan> {
2050 self.input.iter().collect()
2051 }
2052
2053 fn schema(&self) -> &DFSchemaRef {
2054 &self.schema
2055 }
2056
2057 fn expressions(&self) -> Vec<Expr> {
2058 self.input
2059 .iter()
2060 .flat_map(|child| child.expressions())
2061 .collect()
2062 }
2063
2064 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2065 HashSet::from_iter(vec!["c".to_string()])
2066 }
2067
2068 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2069 write!(f, "NoopPlan")
2070 }
2071
2072 fn with_exprs_and_inputs(
2073 &self,
2074 _exprs: Vec<Expr>,
2075 inputs: Vec<LogicalPlan>,
2076 ) -> Result<Self> {
2077 Ok(Self {
2078 input: inputs,
2079 schema: Arc::clone(&self.schema),
2080 })
2081 }
2082
2083 fn supports_limit_pushdown(&self) -> bool {
2084 false }
2086 }
2087
2088 #[test]
2089 fn user_defined_plan() -> Result<()> {
2090 let table_scan = test_table_scan()?;
2091
2092 let custom_plan = LogicalPlan::Extension(Extension {
2093 node: Arc::new(NoopPlan {
2094 input: vec![table_scan.clone()],
2095 schema: Arc::clone(table_scan.schema()),
2096 }),
2097 });
2098 let plan = LogicalPlanBuilder::from(custom_plan)
2099 .filter(col("a").eq(lit(1i64)))?
2100 .build()?;
2101
2102 assert_optimized_plan_equal!(
2104 plan,
2105 @r"
2106 NoopPlan
2107 TableScan: test, full_filters=[test.a = Int64(1)]
2108 "
2109 )?;
2110
2111 let custom_plan = LogicalPlan::Extension(Extension {
2112 node: Arc::new(NoopPlan {
2113 input: vec![table_scan.clone()],
2114 schema: Arc::clone(table_scan.schema()),
2115 }),
2116 });
2117 let plan = LogicalPlanBuilder::from(custom_plan)
2118 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2119 .build()?;
2120
2121 assert_optimized_plan_equal!(
2123 plan,
2124 @r"
2125 Filter: test.c = Int64(2)
2126 NoopPlan
2127 TableScan: test, full_filters=[test.a = Int64(1)]
2128 "
2129 )?;
2130
2131 let custom_plan = LogicalPlan::Extension(Extension {
2132 node: Arc::new(NoopPlan {
2133 input: vec![table_scan.clone(), table_scan.clone()],
2134 schema: Arc::clone(table_scan.schema()),
2135 }),
2136 });
2137 let plan = LogicalPlanBuilder::from(custom_plan)
2138 .filter(col("a").eq(lit(1i64)))?
2139 .build()?;
2140
2141 assert_optimized_plan_equal!(
2143 plan,
2144 @r"
2145 NoopPlan
2146 TableScan: test, full_filters=[test.a = Int64(1)]
2147 TableScan: test, full_filters=[test.a = Int64(1)]
2148 "
2149 )?;
2150
2151 let custom_plan = LogicalPlan::Extension(Extension {
2152 node: Arc::new(NoopPlan {
2153 input: vec![table_scan.clone(), table_scan.clone()],
2154 schema: Arc::clone(table_scan.schema()),
2155 }),
2156 });
2157 let plan = LogicalPlanBuilder::from(custom_plan)
2158 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2159 .build()?;
2160
2161 assert_optimized_plan_equal!(
2163 plan,
2164 @r"
2165 Filter: test.c = Int64(2)
2166 NoopPlan
2167 TableScan: test, full_filters=[test.a = Int64(1)]
2168 TableScan: test, full_filters=[test.a = Int64(1)]
2169 "
2170 )
2171 }
2172
2173 #[test]
2176 fn multi_filter() -> Result<()> {
2177 let table_scan = test_table_scan()?;
2179 let plan = LogicalPlanBuilder::from(table_scan)
2180 .project(vec![col("a").alias("b"), col("c")])?
2181 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2182 .filter(col("b").gt(lit(10i64)))?
2183 .filter(col("sum(test.c)").gt(lit(10i64)))?
2184 .build()?;
2185
2186 assert_snapshot!(plan,
2188 @r"
2189 Filter: sum(test.c) > Int64(10)
2190 Filter: b > Int64(10)
2191 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2192 Projection: test.a AS b, test.c
2193 TableScan: test
2194 ",
2195 );
2196 assert_optimized_plan_equal!(
2198 plan,
2199 @r"
2200 Filter: sum(test.c) > Int64(10)
2201 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2202 Projection: test.a AS b, test.c
2203 TableScan: test, full_filters=[test.a > Int64(10)]
2204 "
2205 )
2206 }
2207
2208 #[test]
2211 fn split_filter() -> Result<()> {
2212 let table_scan = test_table_scan()?;
2214 let plan = LogicalPlanBuilder::from(table_scan)
2215 .project(vec![col("a").alias("b"), col("c")])?
2216 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2217 .filter(and(
2218 col("sum(test.c)").gt(lit(10i64)),
2219 and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2220 ))?
2221 .build()?;
2222
2223 assert_snapshot!(plan,
2225 @r"
2226 Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2227 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2228 Projection: test.a AS b, test.c
2229 TableScan: test
2230 ",
2231 );
2232 assert_optimized_plan_equal!(
2234 plan,
2235 @r"
2236 Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2237 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2238 Projection: test.a AS b, test.c
2239 TableScan: test, full_filters=[test.a > Int64(10)]
2240 "
2241 )
2242 }
2243
2244 #[test]
2246 fn double_limit() -> Result<()> {
2247 let table_scan = test_table_scan()?;
2248 let plan = LogicalPlanBuilder::from(table_scan)
2249 .project(vec![col("a"), col("b")])?
2250 .limit(0, Some(20))?
2251 .limit(0, Some(10))?
2252 .project(vec![col("a"), col("b")])?
2253 .filter(col("a").eq(lit(1i64)))?
2254 .build()?;
2255 assert_optimized_plan_equal!(
2257 plan,
2258 @r"
2259 Projection: test.a, test.b
2260 Filter: test.a = Int64(1)
2261 Limit: skip=0, fetch=10
2262 Limit: skip=0, fetch=20
2263 Projection: test.a, test.b
2264 TableScan: test
2265 "
2266 )
2267 }
2268
2269 #[test]
2270 fn union_all() -> Result<()> {
2271 let table_scan = test_table_scan()?;
2272 let table_scan2 = test_table_scan_with_name("test2")?;
2273 let plan = LogicalPlanBuilder::from(table_scan)
2274 .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2275 .filter(col("a").eq(lit(1i64)))?
2276 .build()?;
2277 assert_optimized_plan_equal!(
2279 plan,
2280 @r"
2281 Union
2282 TableScan: test, full_filters=[test.a = Int64(1)]
2283 TableScan: test2, full_filters=[test2.a = Int64(1)]
2284 "
2285 )
2286 }
2287
2288 #[test]
2289 fn union_all_on_projection() -> Result<()> {
2290 let table_scan = test_table_scan()?;
2291 let table = LogicalPlanBuilder::from(table_scan)
2292 .project(vec![col("a").alias("b")])?
2293 .alias("test2")?;
2294
2295 let plan = table
2296 .clone()
2297 .union(table.build()?)?
2298 .filter(col("b").eq(lit(1i64)))?
2299 .build()?;
2300
2301 assert_optimized_plan_equal!(
2303 plan,
2304 @r"
2305 Union
2306 SubqueryAlias: test2
2307 Projection: test.a AS b
2308 TableScan: test, full_filters=[test.a = Int64(1)]
2309 SubqueryAlias: test2
2310 Projection: test.a AS b
2311 TableScan: test, full_filters=[test.a = Int64(1)]
2312 "
2313 )
2314 }
2315
2316 #[test]
2317 fn test_union_different_schema() -> Result<()> {
2318 let left = LogicalPlanBuilder::from(test_table_scan()?)
2319 .project(vec![col("a"), col("b"), col("c")])?
2320 .build()?;
2321
2322 let schema = Schema::new(vec![
2323 Field::new("d", DataType::UInt32, false),
2324 Field::new("e", DataType::UInt32, false),
2325 Field::new("f", DataType::UInt32, false),
2326 ]);
2327 let right = table_scan(Some("test1"), &schema, None)?
2328 .project(vec![col("d"), col("e"), col("f")])?
2329 .build()?;
2330 let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2331 let plan = LogicalPlanBuilder::from(left)
2332 .cross_join(right)?
2333 .project(vec![col("test.a"), col("test1.d")])?
2334 .filter(filter)?
2335 .build()?;
2336
2337 assert_optimized_plan_equal!(
2338 plan,
2339 @r"
2340 Projection: test.a, test1.d
2341 Cross Join:
2342 Projection: test.a, test.b, test.c
2343 TableScan: test, full_filters=[test.a = Int32(1)]
2344 Projection: test1.d, test1.e, test1.f
2345 TableScan: test1, full_filters=[test1.d > Int32(2)]
2346 "
2347 )
2348 }
2349
2350 #[test]
2351 fn test_project_same_name_different_qualifier() -> Result<()> {
2352 let table_scan = test_table_scan()?;
2353 let left = LogicalPlanBuilder::from(table_scan)
2354 .project(vec![col("a"), col("b"), col("c")])?
2355 .build()?;
2356 let right_table_scan = test_table_scan_with_name("test1")?;
2357 let right = LogicalPlanBuilder::from(right_table_scan)
2358 .project(vec![col("a"), col("b"), col("c")])?
2359 .build()?;
2360 let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2361 let plan = LogicalPlanBuilder::from(left)
2362 .cross_join(right)?
2363 .project(vec![col("test.a"), col("test1.a")])?
2364 .filter(filter)?
2365 .build()?;
2366
2367 assert_optimized_plan_equal!(
2368 plan,
2369 @r"
2370 Projection: test.a, test1.a
2371 Cross Join:
2372 Projection: test.a, test.b, test.c
2373 TableScan: test, full_filters=[test.a = Int32(1)]
2374 Projection: test1.a, test1.b, test1.c
2375 TableScan: test1, full_filters=[test1.a > Int32(2)]
2376 "
2377 )
2378 }
2379
2380 #[test]
2382 fn filter_2_breaks_limits() -> Result<()> {
2383 let table_scan = test_table_scan()?;
2384 let plan = LogicalPlanBuilder::from(table_scan)
2385 .project(vec![col("a")])?
2386 .filter(col("a").lt_eq(lit(1i64)))?
2387 .limit(0, Some(1))?
2388 .project(vec![col("a")])?
2389 .filter(col("a").gt_eq(lit(1i64)))?
2390 .build()?;
2391 assert_snapshot!(plan,
2395 @r"
2396 Filter: test.a >= Int64(1)
2397 Projection: test.a
2398 Limit: skip=0, fetch=1
2399 Filter: test.a <= Int64(1)
2400 Projection: test.a
2401 TableScan: test
2402 ",
2403 );
2404 assert_optimized_plan_equal!(
2405 plan,
2406 @r"
2407 Projection: test.a
2408 Filter: test.a >= Int64(1)
2409 Limit: skip=0, fetch=1
2410 Projection: test.a
2411 TableScan: test, full_filters=[test.a <= Int64(1)]
2412 "
2413 )
2414 }
2415
2416 #[test]
2418 fn two_filters_on_same_depth() -> Result<()> {
2419 let table_scan = test_table_scan()?;
2420 let plan = LogicalPlanBuilder::from(table_scan)
2421 .limit(0, Some(1))?
2422 .filter(col("a").lt_eq(lit(1i64)))?
2423 .filter(col("a").gt_eq(lit(1i64)))?
2424 .project(vec![col("a")])?
2425 .build()?;
2426
2427 assert_snapshot!(plan,
2429 @r"
2430 Projection: test.a
2431 Filter: test.a >= Int64(1)
2432 Filter: test.a <= Int64(1)
2433 Limit: skip=0, fetch=1
2434 TableScan: test
2435 ",
2436 );
2437 assert_optimized_plan_equal!(
2438 plan,
2439 @r"
2440 Projection: test.a
2441 Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2442 Limit: skip=0, fetch=1
2443 TableScan: test
2444 "
2445 )
2446 }
2447
2448 #[test]
2451 fn filters_user_defined_node() -> Result<()> {
2452 let table_scan = test_table_scan()?;
2453 let plan = LogicalPlanBuilder::from(table_scan)
2454 .filter(col("a").lt_eq(lit(1i64)))?
2455 .build()?;
2456
2457 let plan = user_defined::new(plan);
2458
2459 assert_snapshot!(plan,
2461 @r"
2462 TestUserDefined
2463 Filter: test.a <= Int64(1)
2464 TableScan: test
2465 ",
2466 );
2467 assert_optimized_plan_equal!(
2468 plan,
2469 @r"
2470 TestUserDefined
2471 TableScan: test, full_filters=[test.a <= Int64(1)]
2472 "
2473 )
2474 }
2475
2476 #[test]
2478 fn filter_on_join_on_common_independent() -> Result<()> {
2479 let table_scan = test_table_scan()?;
2480 let left = LogicalPlanBuilder::from(table_scan).build()?;
2481 let right_table_scan = test_table_scan_with_name("test2")?;
2482 let right = LogicalPlanBuilder::from(right_table_scan)
2483 .project(vec![col("a")])?
2484 .build()?;
2485 let plan = LogicalPlanBuilder::from(left)
2486 .join(
2487 right,
2488 JoinType::Inner,
2489 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2490 None,
2491 )?
2492 .filter(col("test.a").lt_eq(lit(1i64)))?
2493 .build()?;
2494
2495 assert_snapshot!(plan,
2497 @r"
2498 Filter: test.a <= Int64(1)
2499 Inner Join: test.a = test2.a
2500 TableScan: test
2501 Projection: test2.a
2502 TableScan: test2
2503 ",
2504 );
2505 assert_optimized_plan_equal!(
2507 plan,
2508 @r"
2509 Inner Join: test.a = test2.a
2510 TableScan: test, full_filters=[test.a <= Int64(1)]
2511 Projection: test2.a
2512 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2513 "
2514 )
2515 }
2516
2517 #[test]
2519 fn filter_using_join_on_common_independent() -> Result<()> {
2520 let table_scan = test_table_scan()?;
2521 let left = LogicalPlanBuilder::from(table_scan).build()?;
2522 let right_table_scan = test_table_scan_with_name("test2")?;
2523 let right = LogicalPlanBuilder::from(right_table_scan)
2524 .project(vec![col("a")])?
2525 .build()?;
2526 let plan = LogicalPlanBuilder::from(left)
2527 .join_using(
2528 right,
2529 JoinType::Inner,
2530 vec![Column::from_name("a".to_string())],
2531 )?
2532 .filter(col("a").lt_eq(lit(1i64)))?
2533 .build()?;
2534
2535 assert_snapshot!(plan,
2537 @r"
2538 Filter: test.a <= Int64(1)
2539 Inner Join: Using test.a = test2.a
2540 TableScan: test
2541 Projection: test2.a
2542 TableScan: test2
2543 ",
2544 );
2545 assert_optimized_plan_equal!(
2547 plan,
2548 @r"
2549 Inner Join: Using test.a = test2.a
2550 TableScan: test, full_filters=[test.a <= Int64(1)]
2551 Projection: test2.a
2552 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2553 "
2554 )
2555 }
2556
2557 #[test]
2559 fn filter_join_on_common_dependent() -> Result<()> {
2560 let table_scan = test_table_scan()?;
2561 let left = LogicalPlanBuilder::from(table_scan)
2562 .project(vec![col("a"), col("c")])?
2563 .build()?;
2564 let right_table_scan = test_table_scan_with_name("test2")?;
2565 let right = LogicalPlanBuilder::from(right_table_scan)
2566 .project(vec![col("a"), col("b")])?
2567 .build()?;
2568 let plan = LogicalPlanBuilder::from(left)
2569 .join(
2570 right,
2571 JoinType::Inner,
2572 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2573 None,
2574 )?
2575 .filter(col("c").lt_eq(col("b")))?
2576 .build()?;
2577
2578 assert_snapshot!(plan,
2580 @r"
2581 Filter: test.c <= test2.b
2582 Inner Join: test.a = test2.a
2583 Projection: test.a, test.c
2584 TableScan: test
2585 Projection: test2.a, test2.b
2586 TableScan: test2
2587 ",
2588 );
2589 assert_optimized_plan_equal!(
2591 plan,
2592 @r"
2593 Inner Join: test.a = test2.a Filter: test.c <= test2.b
2594 Projection: test.a, test.c
2595 TableScan: test
2596 Projection: test2.a, test2.b
2597 TableScan: test2
2598 "
2599 )
2600 }
2601
2602 #[test]
2604 fn filter_join_on_one_side() -> Result<()> {
2605 let table_scan = test_table_scan()?;
2606 let left = LogicalPlanBuilder::from(table_scan)
2607 .project(vec![col("a"), col("b")])?
2608 .build()?;
2609 let table_scan_right = test_table_scan_with_name("test2")?;
2610 let right = LogicalPlanBuilder::from(table_scan_right)
2611 .project(vec![col("a"), col("c")])?
2612 .build()?;
2613
2614 let plan = LogicalPlanBuilder::from(left)
2615 .join(
2616 right,
2617 JoinType::Inner,
2618 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2619 None,
2620 )?
2621 .filter(col("b").lt_eq(lit(1i64)))?
2622 .build()?;
2623
2624 assert_snapshot!(plan,
2626 @r"
2627 Filter: test.b <= Int64(1)
2628 Inner Join: test.a = test2.a
2629 Projection: test.a, test.b
2630 TableScan: test
2631 Projection: test2.a, test2.c
2632 TableScan: test2
2633 ",
2634 );
2635 assert_optimized_plan_equal!(
2636 plan,
2637 @r"
2638 Inner Join: test.a = test2.a
2639 Projection: test.a, test.b
2640 TableScan: test, full_filters=[test.b <= Int64(1)]
2641 Projection: test2.a, test2.c
2642 TableScan: test2
2643 "
2644 )
2645 }
2646
2647 #[test]
2650 fn filter_using_left_join() -> Result<()> {
2651 let table_scan = test_table_scan()?;
2652 let left = LogicalPlanBuilder::from(table_scan).build()?;
2653 let right_table_scan = test_table_scan_with_name("test2")?;
2654 let right = LogicalPlanBuilder::from(right_table_scan)
2655 .project(vec![col("a")])?
2656 .build()?;
2657 let plan = LogicalPlanBuilder::from(left)
2658 .join_using(
2659 right,
2660 JoinType::Left,
2661 vec![Column::from_name("a".to_string())],
2662 )?
2663 .filter(col("test2.a").lt_eq(lit(1i64)))?
2664 .build()?;
2665
2666 assert_snapshot!(plan,
2668 @r"
2669 Filter: test2.a <= Int64(1)
2670 Left Join: Using test.a = test2.a
2671 TableScan: test
2672 Projection: test2.a
2673 TableScan: test2
2674 ",
2675 );
2676 assert_optimized_plan_equal!(
2678 plan,
2679 @r"
2680 Filter: test2.a <= Int64(1)
2681 Left Join: Using test.a = test2.a
2682 TableScan: test, full_filters=[test.a <= Int64(1)]
2683 Projection: test2.a
2684 TableScan: test2
2685 "
2686 )
2687 }
2688
/// RIGHT JOIN ... USING (a) followed by a filter on the left-side copy of the
/// join column (`test.a`).
#[test]
fn filter_using_right_join() -> Result<()> {
    // Left input is the bare `test` scan; right input is `test2` narrowed to `a`.
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_keys = vec![Column::from_name("a")];
    let left_side_predicate = col("test.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Right, using_keys)?
        .filter(left_side_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Expected result: the Filter stays above the join, and the equivalent
    // predicate `test2.a <= Int64(1)` additionally appears on the right scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2729
/// LEFT JOIN ... USING (a) followed by a filter on the unqualified join column
/// `a`, which the built plan shows resolved to `test.a`.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_keys = vec![Column::from_name("a")];
    let common_column_predicate = col("a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Left, using_keys)?
        .filter(common_column_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Expected result: the Filter node is gone and the predicate sits on the
    // left scan only; the right scan carries no filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2
    "
    )
}
2770
/// RIGHT JOIN ... USING (a) followed by a filter on the right-side join column
/// (`test2.a`).
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_keys = vec![Column::from_name("a")];
    let right_side_predicate = col("test2.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join_using(right_input, JoinType::Right, using_keys)?
        .filter(right_side_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Expected result: the Filter node is gone and the predicate sits on the
    // right scan only; the left scan carries no filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: Using test.a = test2.a
      TableScan: test
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2811
/// Inner join whose ON filter mixes a left-only, a two-sided and a right-only
/// conjunct.
#[test]
fn join_on_with_filter() -> Result<()> {
    // Both inputs expose a, b, c through an explicit projection.
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.c").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected result: each single-sided conjunct lands on its own scan and
    // only the two-sided comparison remains in the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.c > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2857
/// Inner join whose ON filter consists solely of single-sided conjuncts, so the
/// expected optimized join carries no filter at all.
#[test]
fn join_filter_removed() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.b").gt(lit(1u32));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected result: both conjuncts move to their scans; the join line has
    // no `Filter:` part left.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2902
/// Inner join on `test.a = test2.b` whose ON filter constrains only the left
/// join key.
#[test]
fn join_filter_on_common() -> Result<()> {
    // Each side projects exactly the column it joins on.
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;

    let on_filter = col("test.a").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
      Projection: test.a
        TableScan: test
      Projection: test2.b
        TableScan: test2
    ",
    );
    // Expected result: the predicate on the left key is also applied to the
    // right key (`test2.b > UInt32(1)`), and the join filter disappears.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.b
      Projection: test.a
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
2945
/// Left join whose ON filter mixes a left-only, a two-sided and a right-only
/// conjunct.
#[test]
fn left_join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Left, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected result: only the right-only conjunct reaches a scan; the
    // left-only and two-sided conjuncts remain in the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2991
/// Right join whose ON filter mixes a left-only, a two-sided and a right-only
/// conjunct.
#[test]
fn right_join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Right, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected result: only the left-only conjunct reaches a scan; the
    // two-sided and right-only conjuncts remain in the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3037
/// Full join whose ON filter mixes a left-only, a two-sided and a right-only
/// conjunct; the expected optimized plan is identical to the input plan.
#[test]
fn full_join_on_with_filter() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let two_sided = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(two_sided).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Full, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Expected result: nothing moves; all three conjuncts stay in the join
    // filter and neither scan receives a filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    "
    )
}
3083
3084 struct PushDownProvider {
3085 pub filter_support: TableProviderFilterPushDown,
3086 }
3087
3088 #[async_trait]
3089 impl TableSource for PushDownProvider {
3090 fn schema(&self) -> SchemaRef {
3091 Arc::new(Schema::new(vec![
3092 Field::new("a", DataType::Int32, true),
3093 Field::new("b", DataType::Int32, true),
3094 ]))
3095 }
3096
3097 fn table_type(&self) -> TableType {
3098 TableType::Base
3099 }
3100
3101 fn supports_filters_pushdown(
3102 &self,
3103 filters: &[&Expr],
3104 ) -> Result<Vec<TableProviderFilterPushDown>> {
3105 Ok((0..filters.len())
3106 .map(|_| self.filter_support.clone())
3107 .collect())
3108 }
3109
3110 fn as_any(&self) -> &dyn Any {
3111 self
3112 }
3113 }
3114
3115 fn table_scan_with_pushdown_provider_builder(
3116 filter_support: TableProviderFilterPushDown,
3117 filters: Vec<Expr>,
3118 projection: Option<Vec<usize>>,
3119 ) -> Result<LogicalPlanBuilder> {
3120 let test_provider = PushDownProvider { filter_support };
3121
3122 let table_scan = LogicalPlan::TableScan(TableScan {
3123 table_name: "test".into(),
3124 filters,
3125 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3126 projection,
3127 source: Arc::new(test_provider),
3128 fetch: None,
3129 });
3130
3131 Ok(LogicalPlanBuilder::from(table_scan))
3132 }
3133
3134 fn table_scan_with_pushdown_provider(
3135 filter_support: TableProviderFilterPushDown,
3136 ) -> Result<LogicalPlan> {
3137 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3138 .filter(col("a").eq(lit(1i64)))?
3139 .build()
3140 }
3141
/// With `Exact` support the expected plan is the scan alone, carrying the
/// predicate in `full_filters`.
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
    let support = TableProviderFilterPushDown::Exact;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[a = Int64(1)]"
    )
}
3151
/// With `Inexact` support the expected plan keeps the Filter node and also
/// lists the predicate in the scan's `partial_filters`.
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3165
/// Applies the rule by hand once and then passes the already-rewritten plan to
/// the assertion macro; the expected output is the same as for a single pass.
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    // First, explicit invocation of the rule.
    let rule = PushDownFilter::new();
    let context = OptimizerContext::new();
    let first_pass = rule
        .rewrite(plan, &context)
        .expect("failed to optimize plan");
    let optimized_plan = first_pass.data;

    // NOTE(review): the macro presumably runs the rule again on its input,
    // which would make this the second invocation — confirm in its definition.
    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test, partial_filters=[a = Int64(1)]
    "
    )
}
3186
/// With `Unsupported` the expected plan keeps the Filter node and the scan
/// lists no filters.
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
    let support = TableProviderFilterPushDown::Unsupported;
    let plan = table_scan_with_pushdown_provider(support)?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: a = Int64(1)
      TableScan: test
    "
    )
}
3200
/// `Inexact` provider whose scan already lists both predicates, topped by a
/// Filter that repeats them as a conjunction.
#[test]
fn multi_combined_filter() -> Result<()> {
    let a_eq_10 = col("a").eq(lit(10i64));
    let b_gt_11 = col("b").gt(lit(11i64));

    // Scan projects column index 0 only and starts with both predicates attached.
    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Inexact,
        vec![a_eq_10.clone(), b_gt_11.clone()],
        Some(vec![0]),
    )?;
    let plan = scan_builder
        .filter(a_eq_10.and(b_gt_11))?
        .project(vec![col("a"), col("b")])?
        .build()?;

    // Expected result: the Filter is kept and the scan's `partial_filters`
    // lists each predicate once.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: a = Int64(10) AND b > Int64(11)
        TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3221
/// `Exact` provider with no initial scan filters, topped by a two-conjunct
/// Filter.
#[test]
fn multi_combined_filter_exact() -> Result<()> {
    // Scan projects column index 0 only and starts without filters.
    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Exact,
        Vec::new(),
        Some(vec![0]),
    )?;

    let a_eq_10 = col("a").eq(lit(10i64));
    let b_gt_11 = col("b").gt(lit(11i64));
    let plan = scan_builder
        .filter(a_eq_10.and(b_gt_11))?
        .project(vec![col("a"), col("b")])?
        .build()?;

    // Expected result: the Filter node is gone and both conjuncts appear as
    // separate entries in the scan's `full_filters`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
    "
    )
}
3241
/// Filter above a projection that renames `test.a` to `b`; the filter refers
/// to the alias `b` and to the pass-through column `c`.
#[test]
fn test_filter_with_alias() -> Result<()> {
    let on_alias = col("b").gt(lit(10i64));
    let on_plain = col("c").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(on_alias.and(on_plain))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // Expected result: the alias is rewritten back to `test.a` and both
    // conjuncts end up on the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3270
/// Same as `test_filter_with_alias`, but with a second pass-through projection
/// between the aliasing projection and the filter.
#[test]
fn test_filter_with_alias_2() -> Result<()> {
    let on_alias = col("b").gt(lit(10i64));
    let on_plain = col("c").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(on_alias.and(on_plain))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND test.c > Int64(10)
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // Expected result: both conjuncts pass through both projections and end
    // up on the scan, with the alias rewritten back to `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3302
/// Filter above a projection that renames both columns (`a` to `b`, `c` to
/// `d`); the filter refers to both aliases.
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
    let on_b = col("b").gt(lit(10i64));
    let on_d = col("d").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c").alias("d")])?
        .filter(on_b.and(on_d))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: b > Int64(10) AND d > Int64(10)
      Projection: test.a AS b, test.c AS d
        TableScan: test
    ",
    );
    // Expected result: each alias is rewritten to its source column and both
    // conjuncts end up on the scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c AS d
      TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
    "
    )
}
3328
/// Inner join on aliased columns (`c = d`) whose ON filter constrains the left
/// alias only.
#[test]
fn join_filter_with_alias() -> Result<()> {
    // Left exposes `test.a` as `c`; right exposes `test2.b` as `d`.
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;

    let on_filter = col("c").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Inner Join: c = d Filter: c > UInt32(1)
      Projection: test.a AS c
        TableScan: test
      Projection: test2.b AS d
        TableScan: test2
    ",
    );
    // Expected result: the predicate reaches both scans, expressed on the
    // un-aliased source columns, and the join filter disappears.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: c = d
      Projection: test.a AS c
        TableScan: test, full_filters=[test.a > UInt32(1)]
      Projection: test2.b AS d
        TableScan: test2, full_filters=[test2.b > UInt32(1)]
    "
    )
}
3371
/// `IN (list)` filter on an aliased column above the aliasing projection.
#[test]
fn test_in_filter_with_alias() -> Result<()> {
    let candidates = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    // `false` = not negated, i.e. a plain `IN`.
    let membership = in_list(col("b"), candidates, false);

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(membership)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // Expected result: the IN-list ends up on the scan, expressed on `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3401
/// `IN (list)` filter on an aliased column, with a pass-through projection
/// between the aliasing projection and the filter.
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
    let candidates = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    // `false` = not negated, i.e. a plain `IN`.
    let membership = in_list(col("b"), candidates, false);

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(membership)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
      Projection: b, test.c
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // Expected result: the IN-list passes through both projections and ends
    // up on the scan, expressed on `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b, test.c
      Projection: test.a AS b, test.c
        TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
    "
    )
}
3434
/// `IN (<subquery>)` filter on an aliased column above the aliasing projection.
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
    // Subquery side: `SELECT c FROM sq`.
    let subquery_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .project(vec![col("c")])?
        .build()?;
    let membership = in_subquery(col("b"), Arc::new(subquery_plan));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(membership)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: b IN (<subquery>)
      Subquery:
        Projection: sq.c
          TableScan: sq
      Projection: test.a AS b, test.c
        TableScan: test
    ",
    );
    // Expected result: the predicate (and its subquery) ends up on the scan,
    // expressed on `test.a`.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a IN (<subquery>)]
        Subquery:
          Projection: sq.c
            TableScan: sq
    "
    )
}
3474
/// Filter above two nested `SubqueryAlias: b` / `Projection` layers over a
/// one-row empty relation that projects the literal `0` as `a`.
#[test]
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    // `empty(true)` yields the `EmptyRelation: rows=1` leaf seen in the snapshot.
    let aliased_twice = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?
        .project(vec![col("b.a")])?
        .alias("b")?;
    let predicate = col("b.a").eq(lit(1i64));
    let plan = aliased_twice
        .filter(predicate)?
        .project(vec![col("b.a")])?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Projection: b.a
      Filter: b.a = Int64(1)
        SubqueryAlias: b
          Projection: b.a
            SubqueryAlias: b
              Projection: Int64(0) AS a
                EmptyRelation: rows=1
    ",
    );
    // Expected result: the filter sinks below every alias and projection, and
    // `b.a` is replaced by the literal it stands for (`Int64(0)`).
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b.a
      SubqueryAlias: b
        Projection: b.a
          SubqueryAlias: b
            Projection: Int64(0) AS a
              Filter: Int64(0) = Int64(1)
                EmptyRelation: rows=1
    "
    )
}
3513
/// Cross join followed by a filter of the form `(A AND p) OR (B AND q)`, where
/// A and B compare columns across both sides and p and q touch the left side only.
#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    // Right side exposes `test1.a` twice, under the names `d` and `e`.
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;

    let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
    let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
    let disjunction = or(first_branch, second_branch);

    let plan = LogicalPlanBuilder::from(left_input)
        .cross_join(right_input)?
        .filter(disjunction)?
        .build()?;

    // NOTE(review): the join line is written with two spaces after
    // `Inner Join:` on the assumption that the formatter prints an empty
    // equi-join key list before ` Filter:` — confirm against the repository
    // snapshot.
    assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
    Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    ")?;

    // Apply the rule by hand once, then assert on the already-rewritten plan;
    // the expected output is the same as above.
    let context = OptimizerContext::new();
    let rewritten = PushDownFilter::new()
        .rewrite(plan, &context)
        .expect("failed to optimize plan");
    let optimized_plan = rewritten.data;
    assert_optimized_plan_equal!(
        optimized_plan,
        @r"
    Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
      Projection: test1.a AS d, test1.a AS e
        TableScan: test1
    "
    )
}
3559
/// LeftSemi join on `test1.a = test2.a` followed by a filter on the right-side
/// join key.
#[test]
fn left_semi_join() -> Result<()> {
    // Left is the bare `test1` scan; right is `test2` projected to a, b.
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let right_key_predicate = col("test2.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftSemi, join_keys, None)?
        .filter(right_key_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected result: the Filter stays above the join, and the equivalent
    // predicate on `test1.a` additionally appears on the left scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      LeftSemi Join: test1.a = test2.a
        TableScan: test1, full_filters=[test1.a <= Int64(1)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3602
/// LeftSemi join whose ON filter has one left-only and one right-only conjunct.
#[test]
fn left_semi_join_with_filters() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);

    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftSemi, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected result: both conjuncts move to their respective scans and the
    // join filter disappears.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3646
/// RightSemi join on `test1.a = test2.a` followed by a filter on the left-side
/// join key.
#[test]
fn right_semi_join() -> Result<()> {
    // Left is the bare `test1` scan; right is `test2` projected to a, b.
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_key_predicate = col("test1.a").lt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightSemi, join_keys, None)?
        .filter(left_key_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected result: the Filter stays above the join, and the equivalent
    // predicate on `test2.a` additionally appears on the right scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a <= Int64(1)
      RightSemi Join: test1.a = test2.a
        TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
3689
/// RightSemi join whose ON filter has one left-only and one right-only conjunct.
#[test]
fn right_semi_join_with_filters() -> Result<()> {
    let left_input = test_table_scan_with_name("test1")?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);

    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightSemi, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected result: both conjuncts move to their respective scans and the
    // join filter disappears.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightSemi Join: test1.a = test2.a
      TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3733
/// LeftAnti join on `test1.a = test2.a` followed by a filter on the right-side
/// join key.
#[test]
fn left_anti_join() -> Result<()> {
    // Both inputs are projected to a, b.
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let right_key_predicate = col("test2.a").gt(lit(2u32));
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftAnti, join_keys, None)?
        .filter(right_key_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected result: the Filter stays above the join, and the equivalent
    // predicate on `test1.a` additionally appears on the left scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a > UInt32(2)
      LeftAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1, full_filters=[test1.a > UInt32(2)]
        Projection: test2.a, test2.b
          TableScan: test2
    "
    )
}
3781
/// LeftAnti join whose ON filter has one left-only and one right-only conjunct.
#[test]
fn left_anti_join_with_filters() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);

    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::LeftAnti, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected result: only the right-only conjunct reaches a scan; the
    // left-only conjunct remains in the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2, full_filters=[test2.b > UInt32(2)]
    "
    )
}
3830
/// RightAnti join on `test1.a = test2.a` followed by a filter on the left-side
/// join key.
#[test]
fn right_anti_join() -> Result<()> {
    // Both inputs are projected to a, b.
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_key_predicate = col("test1.a").gt(lit(2u32));
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightAnti, join_keys, None)?
        .filter(left_key_predicate)?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Expected result: the Filter stays above the join, and the equivalent
    // predicate on `test2.a` additionally appears on the right scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test1.a > UInt32(2)
      RightAnti Join: test1.a = test2.a
        Projection: test1.a, test1.b
          TableScan: test1
        Projection: test2.a, test2.b
          TableScan: test2, full_filters=[test2.a > UInt32(2)]
    "
    )
}
3878
/// RightAnti join whose ON filter has one left-only and one right-only conjunct.
#[test]
fn right_anti_join_with_filters() -> Result<()> {
    let left_input = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right_input = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let left_only = col("test1.b").gt(lit(1u32));
    let right_only = col("test2.b").gt(lit(2u32));
    let on_filter = left_only.and(right_only);

    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::RightAnti, join_keys, Some(on_filter))?
        .build()?;

    // Shape of the plan as built, before the rule is applied.
    assert_snapshot!(plan,
    @r"
    RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1
      Projection: test2.a, test2.b
        TableScan: test2
    ",
    );
    // Expected result: only the left-only conjunct reaches a scan; the
    // right-only conjunct remains in the join filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
      Projection: test1.a, test1.b
        TableScan: test1, full_filters=[test1.b > UInt32(1)]
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
3927
3928 #[derive(Debug, PartialEq, Eq, Hash)]
3929 struct TestScalarUDF {
3930 signature: Signature,
3931 }
3932
3933 impl ScalarUDFImpl for TestScalarUDF {
3934 fn as_any(&self) -> &dyn Any {
3935 self
3936 }
3937 fn name(&self) -> &str {
3938 "TestScalarUDF"
3939 }
3940
3941 fn signature(&self) -> &Signature {
3942 &self.signature
3943 }
3944
3945 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3946 Ok(DataType::Int32)
3947 }
3948
3949 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3950 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3951 }
3952 }
3953
/// Plan under test corresponds to:
/// `SELECT t.a, t.r FROM (SELECT a, sum(b), TestScalarUDF() + 1 AS r FROM test1
/// GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5`.
///
/// The expected optimized plan moves `a > 5` into the table scan, while the
/// conjunct on `r` (defined through the volatile call) remains a `Filter`
/// above the projection that computes `r`.
#[test]
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
    let table_scan = test_table_scan_with_name("test1")?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    let inner_columns = vec![
        col("a"),
        sum(col("b")),
        add(volatile_call, lit(1)).alias("r"),
    ];
    let predicate = col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5)));

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(vec![col("a")], vec![sum(col("b"))])?
        .project(inner_columns)?
        .alias("t")?
        .filter(predicate)?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.a > Int32(5) AND t.r > Float64(0.5)
        SubqueryAlias: t
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.5)
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1, full_filters=[test1.a > Int32(5)]
    "
    )
}
3993
/// Filters on a column defined by a volatile call that is projected on top of
/// an inner join. The expected optimized plan keeps the `Filter` directly
/// above that projection (inside the alias) and leaves both scans without
/// filters.
#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    let left_input =
        LogicalPlanBuilder::from(test_table_scan_with_name("test1")?).build()?;
    let right_input =
        LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;

    // Plain equi-join on `a`, with no additional ON-filter.
    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );

    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, None)?
        .project(vec![col("test1.a").alias("a"), volatile_call.alias("r")])?
        .alias("t")?
        .filter(col("t.r").gt(lit(0.8)))?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.r > Float64(0.8)
        SubqueryAlias: t
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.8)
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    "
    )
}
4045
/// A filter consisting solely of a volatile call compared to a constant: the
/// expected optimized plan moves it below the projection but keeps it as a
/// `Filter` node above the scan, which receives no `full_filters`.
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));
    let predicate = volatile_call.gt(lit(0.1));

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4075
/// Mixes a volatile conjunct with two plain column comparisons. The expected
/// optimized plan puts the two column comparisons into the scan's
/// `full_filters` and keeps only the volatile conjunct as a `Filter` node.
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    // NOTE(review): the column references use the qualifier `t`, although no
    // alias `t` is introduced by this plan; the snapshots pin that exact text.
    let predicate = volatile_call
        .gt(lit(0.1))
        .and(col("t.a").gt(lit(5)))
        .and(col("t.b").gt(lit(10)));

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
    "
    )
}
4109
/// Same mixed predicate as the plain table-scan variant, but over a scan built
/// with `TableProviderFilterPushDown::Unsupported`. The expected optimized
/// plan keeps all three conjuncts in one `Filter` above the scan, with the
/// volatile conjunct listed last.
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    // Zero-argument UDF whose signature is declared volatile.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    let predicate = volatile_call
        .gt(lit(0.1))
        .and(col("t.a").gt(lit(5)))
        .and(col("t.b").gt(lit(10)));

    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?;
    let plan = scan_builder
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: a, b
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4146
/// Puts a constant-false filter on top of a leaf extension node and pins that
/// the optimized plan has the same shape as the input: the `Filter` still
/// sits directly above `TestUserNode`.
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    /// Leaf user-defined logical node: no inputs, no expressions, fixed schema.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    // Two nodes never compare as ordered.
    impl PartialOrd for TestUserNode {
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl TestUserNode {
        /// Creates the node with a schema holding a single unqualified
        /// `Int64` field named `a`.
        fn new() -> Self {
            let column_a = Field::new("a", DataType::Int64, false);
            let schema = DFSchema::new_with_metadata(
                vec![(None, column_a.into())],
                Default::default(),
            )
            .unwrap();

            Self {
                schema: Arc::new(schema),
            }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![]
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            write!(f, "TestUserNode")
        }

        /// Rebuilds the node; since it has neither expressions nor inputs,
        /// both lists must be empty and only the schema handle is shared.
        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            let schema = Arc::clone(&self.schema);
            Ok(Self { schema })
        }
    }

    // Wrap the custom node as a plan and filter it with a literal `false`.
    let extension = Extension {
        node: Arc::new(TestUserNode::new()),
    };
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(extension))
        .filter(lit(false))?
        .build()?;

    // Shape of the plan before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: Boolean(false)
      TestUserNode
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: Boolean(false)
      TestUserNode
    "
    )
}
4232
/// A table scan that already carries `fetch = 10`: the expected optimized
/// plan keeps the `Filter` above the scan rather than adding it to the scan's
/// filters.
#[test]
fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> {
    // Start from the plain test scan and set a fetch on it.
    let scan_with_fetch = match test_table_scan()? {
        LogicalPlan::TableScan(mut inner) => {
            inner.fetch = Some(10);
            LogicalPlan::TableScan(inner)
        }
        _ => unreachable!(),
    };

    let predicate = col("a").gt(lit(10i64));
    let plan = LogicalPlanBuilder::from(scan_with_fetch)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a > Int64(10)
      TableScan: test, fetch=10
    "
    )
}
4255
/// A filter above a sort that has no fetch: the expected optimized plan has
/// the predicate in the scan's `full_filters`, below the `Sort`.
#[test]
fn filter_push_down_through_sort_without_fetch() -> Result<()> {
    let table_scan = test_table_scan()?;

    let sort_keys = vec![col("a").sort(true, true)];
    let predicate = col("a").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(table_scan)
        .sort(sort_keys)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Sort: test.a ASC NULLS FIRST
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
4272
/// A filter above a sort that carries `fetch = 5`: the expected optimized
/// plan leaves the `Filter` above the `Sort`, and the scan gets no filters.
#[test]
fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> {
    let table_scan = test_table_scan()?;

    let sort_keys = vec![col("a").sort(true, true)];
    let predicate = col("a").gt(lit(10i64));

    let plan = LogicalPlanBuilder::from(table_scan)
        .sort_with_limit(sort_keys, Some(5))?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a > Int64(10)
      Sort: test.a ASC NULLS FIRST, fetch=5
        TableScan: test
    "
    )
}
4291}