1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use indexmap::IndexSet;
24use itertools::Itertools;
25
26use datafusion_common::tree_node::{
27 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
28};
29use datafusion_common::{
30 internal_err, plan_err, qualified_name, Column, DFSchema, Result,
31};
32use datafusion_expr::expr::WindowFunction;
33use datafusion_expr::expr_rewriter::replace_col;
34use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
35use datafusion_expr::utils::{
36 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
37};
38use datafusion_expr::{
39 and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
40};
41
42use crate::optimizer::ApplyOrder;
43use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
44use crate::{OptimizerConfig, OptimizerRule};
45
/// Optimizer rule that pushes filter predicates further down the logical plan:
/// through projections, sorts, repartitions, aliases, unions, aggregates,
/// windows, unnests, extension nodes and joins, and into table scans.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
135
136pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
161 match join_type {
162 JoinType::Inner => (true, true),
163 JoinType::Left => (true, false),
164 JoinType::Right => (false, true),
165 JoinType::Full => (false, false),
166 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
169 JoinType::RightSemi | JoinType::RightAnti => (false, true),
172 }
173}
174
175pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
185 match join_type {
186 JoinType::Inner => (true, true),
187 JoinType::Left => (false, true),
188 JoinType::Right => (true, false),
189 JoinType::Full => (false, false),
190 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
191 JoinType::LeftAnti => (false, true),
192 JoinType::RightAnti => (true, false),
193 JoinType::LeftMark => (false, true),
194 }
195}
196
/// Classifies predicates as referencing only the left or only the right input
/// of a join. The column sets of each input schema are computed lazily, on
/// first use, and then cached.
#[derive(Debug)]
struct ColumnChecker<'a> {
    /// Schema of the left join input.
    left_schema: &'a DFSchema,
    /// Cached columns of `left_schema`; `None` until first needed.
    left_columns: Option<HashSet<Column>>,
    /// Schema of the right join input.
    right_schema: &'a DFSchema,
    /// Cached columns of `right_schema`; `None` until first needed.
    right_columns: Option<HashSet<Column>>,
}
210
211impl<'a> ColumnChecker<'a> {
212 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
213 Self {
214 left_schema,
215 left_columns: None,
216 right_schema,
217 right_columns: None,
218 }
219 }
220
221 fn is_left_only(&mut self, predicate: &Expr) -> bool {
223 if self.left_columns.is_none() {
224 self.left_columns = Some(schema_columns(self.left_schema));
225 }
226 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
227 }
228
229 fn is_right_only(&mut self, predicate: &Expr) -> bool {
231 if self.right_columns.is_none() {
232 self.right_columns = Some(schema_columns(self.right_schema));
233 }
234 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
235 }
236}
237
238fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
240 schema
241 .iter()
242 .flat_map(|(qualifier, field)| {
243 [
244 Column::new(qualifier.cloned(), field.name()),
245 Column::new_unqualified(field.name()),
247 ]
248 })
249 .collect::<HashSet<_>>()
250}
251
252fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
254 let mut is_evaluate = true;
255 predicate.apply(|expr| match expr {
256 Expr::Column(_)
257 | Expr::Literal(_, _)
258 | Expr::Placeholder(_)
259 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
260 Expr::Exists { .. }
261 | Expr::InSubquery(_)
262 | Expr::ScalarSubquery(_)
263 | Expr::OuterReferenceColumn(_, _)
264 | Expr::Unnest(_) => {
265 is_evaluate = false;
266 Ok(TreeNodeRecursion::Stop)
267 }
268 Expr::Alias(_)
269 | Expr::BinaryExpr(_)
270 | Expr::Like(_)
271 | Expr::SimilarTo(_)
272 | Expr::Not(_)
273 | Expr::IsNotNull(_)
274 | Expr::IsNull(_)
275 | Expr::IsTrue(_)
276 | Expr::IsFalse(_)
277 | Expr::IsUnknown(_)
278 | Expr::IsNotTrue(_)
279 | Expr::IsNotFalse(_)
280 | Expr::IsNotUnknown(_)
281 | Expr::Negative(_)
282 | Expr::Between(_)
283 | Expr::Case(_)
284 | Expr::Cast(_)
285 | Expr::TryCast(_)
286 | Expr::InList { .. }
287 | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
288 #[expect(deprecated)]
290 Expr::AggregateFunction(_)
291 | Expr::WindowFunction(_)
292 | Expr::Wildcard { .. }
293 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
294 })?;
295 Ok(is_evaluate)
296}
297
298fn extract_or_clauses_for_join<'a>(
332 filters: &'a [Expr],
333 schema: &'a DFSchema,
334) -> impl Iterator<Item = Expr> + 'a {
335 let schema_columns = schema_columns(schema);
336
337 filters.iter().filter_map(move |expr| {
339 if let Expr::BinaryExpr(BinaryExpr {
340 left,
341 op: Operator::Or,
342 right,
343 }) = expr
344 {
345 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
346 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
347
348 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
350 return Some(or(left_expr, right_expr));
351 }
352 }
353 None
354 })
355}
356
357fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
369 let mut predicate = None;
370
371 match expr {
372 Expr::BinaryExpr(BinaryExpr {
373 left: l_expr,
374 op: Operator::Or,
375 right: r_expr,
376 }) => {
377 let l_expr = extract_or_clause(l_expr, schema_columns);
378 let r_expr = extract_or_clause(r_expr, schema_columns);
379
380 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
381 predicate = Some(or(l_expr, r_expr));
382 }
383 }
384 Expr::BinaryExpr(BinaryExpr {
385 left: l_expr,
386 op: Operator::And,
387 right: r_expr,
388 }) => {
389 let l_expr = extract_or_clause(l_expr, schema_columns);
390 let r_expr = extract_or_clause(r_expr, schema_columns);
391
392 match (l_expr, r_expr) {
393 (Some(l_expr), Some(r_expr)) => {
394 predicate = Some(and(l_expr, r_expr));
395 }
396 (Some(l_expr), None) => {
397 predicate = Some(l_expr);
398 }
399 (None, Some(r_expr)) => {
400 predicate = Some(r_expr);
401 }
402 (None, None) => {
403 predicate = None;
404 }
405 }
406 }
407 _ => {
408 if has_all_column_refs(expr, schema_columns) {
409 predicate = Some(expr.clone());
410 }
411 }
412 }
413
414 predicate
415}
416
417fn push_down_all_join(
419 predicates: Vec<Expr>,
420 inferred_join_predicates: Vec<Expr>,
421 mut join: Join,
422 on_filter: Vec<Expr>,
423) -> Result<Transformed<LogicalPlan>> {
424 let is_inner_join = join.join_type == JoinType::Inner;
425 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
427
428 let left_schema = join.left.schema();
433 let right_schema = join.right.schema();
434 let mut left_push = vec![];
435 let mut right_push = vec![];
436 let mut keep_predicates = vec![];
437 let mut join_conditions = vec![];
438 let mut checker = ColumnChecker::new(left_schema, right_schema);
439 for predicate in predicates {
440 if left_preserved && checker.is_left_only(&predicate) {
441 left_push.push(predicate);
442 } else if right_preserved && checker.is_right_only(&predicate) {
443 right_push.push(predicate);
444 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
445 join_conditions.push(predicate);
448 } else {
449 keep_predicates.push(predicate);
450 }
451 }
452
453 for predicate in inferred_join_predicates {
455 if left_preserved && checker.is_left_only(&predicate) {
456 left_push.push(predicate);
457 } else if right_preserved && checker.is_right_only(&predicate) {
458 right_push.push(predicate);
459 }
460 }
461
462 let mut on_filter_join_conditions = vec![];
463 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
464
465 if !on_filter.is_empty() {
466 for on in on_filter {
467 if on_left_preserved && checker.is_left_only(&on) {
468 left_push.push(on)
469 } else if on_right_preserved && checker.is_right_only(&on) {
470 right_push.push(on)
471 } else {
472 on_filter_join_conditions.push(on)
473 }
474 }
475 }
476
477 if left_preserved {
480 left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
481 left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
482 }
483 if right_preserved {
484 right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
485 right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
486 }
487
488 if on_left_preserved {
491 left_push.extend(extract_or_clauses_for_join(
492 &on_filter_join_conditions,
493 left_schema,
494 ));
495 }
496 if on_right_preserved {
497 right_push.extend(extract_or_clauses_for_join(
498 &on_filter_join_conditions,
499 right_schema,
500 ));
501 }
502
503 if let Some(predicate) = conjunction(left_push) {
504 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
505 }
506 if let Some(predicate) = conjunction(right_push) {
507 join.right =
508 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
509 }
510
511 join_conditions.extend(on_filter_join_conditions);
513 join.filter = conjunction(join_conditions);
514
515 let plan = LogicalPlan::Join(join);
517 let plan = if let Some(predicate) = conjunction(keep_predicates) {
518 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
519 } else {
520 plan
521 };
522 Ok(Transformed::yes(plan))
523}
524
525fn push_down_join(
526 join: Join,
527 parent_predicate: Option<&Expr>,
528) -> Result<Transformed<LogicalPlan>> {
529 let predicates = parent_predicate
531 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
532
533 let on_filters = join
535 .filter
536 .as_ref()
537 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
538
539 let inferred_join_predicates =
541 infer_join_predicates(&join, &predicates, &on_filters)?;
542
543 if on_filters.is_empty()
544 && predicates.is_empty()
545 && inferred_join_predicates.is_empty()
546 {
547 return Ok(Transformed::no(LogicalPlan::Join(join)));
548 }
549
550 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
551}
552
553fn infer_join_predicates(
564 join: &Join,
565 predicates: &[Expr],
566 on_filters: &[Expr],
567) -> Result<Vec<Expr>> {
568 let join_col_keys = join
570 .on
571 .iter()
572 .filter_map(|(l, r)| {
573 let left_col = l.try_as_col()?;
574 let right_col = r.try_as_col()?;
575 Some((left_col, right_col))
576 })
577 .collect::<Vec<_>>();
578
579 let join_type = join.join_type;
580
581 let mut inferred_predicates = InferredPredicates::new(join_type);
582
583 infer_join_predicates_from_predicates(
584 &join_col_keys,
585 predicates,
586 &mut inferred_predicates,
587 )?;
588
589 infer_join_predicates_from_on_filters(
590 &join_col_keys,
591 join_type,
592 on_filters,
593 &mut inferred_predicates,
594 )?;
595
596 Ok(inferred_predicates.predicates)
597}
598
/// Collects predicates inferred through the equi-join keys of a single join.
struct InferredPredicates {
    /// Accepted predicates, already rewritten onto the other join side.
    predicates: Vec<Expr>,
    /// `true` for inner joins: every candidate is accepted. Otherwise only
    /// candidates that `is_restrict_null_predicate` approves are kept.
    is_inner_join: bool,
}
611
612impl InferredPredicates {
613 fn new(join_type: JoinType) -> Self {
614 Self {
615 predicates: vec![],
616 is_inner_join: matches!(join_type, JoinType::Inner),
617 }
618 }
619
620 fn try_build_predicate(
621 &mut self,
622 predicate: Expr,
623 replace_map: &HashMap<&Column, &Column>,
624 ) -> Result<()> {
625 if self.is_inner_join
626 || matches!(
627 is_restrict_null_predicate(
628 predicate.clone(),
629 replace_map.keys().cloned()
630 ),
631 Ok(true)
632 )
633 {
634 self.predicates.push(replace_col(predicate, replace_map)?);
635 }
636
637 Ok(())
638 }
639}
640
641fn infer_join_predicates_from_predicates(
651 join_col_keys: &[(&Column, &Column)],
652 predicates: &[Expr],
653 inferred_predicates: &mut InferredPredicates,
654) -> Result<()> {
655 infer_join_predicates_impl::<true, true>(
656 join_col_keys,
657 predicates,
658 inferred_predicates,
659 )
660}
661
662fn infer_join_predicates_from_on_filters(
675 join_col_keys: &[(&Column, &Column)],
676 join_type: JoinType,
677 on_filters: &[Expr],
678 inferred_predicates: &mut InferredPredicates,
679) -> Result<()> {
680 match join_type {
681 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
682 JoinType::Inner => infer_join_predicates_impl::<true, true>(
683 join_col_keys,
684 on_filters,
685 inferred_predicates,
686 ),
687 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
688 infer_join_predicates_impl::<true, false>(
689 join_col_keys,
690 on_filters,
691 inferred_predicates,
692 )
693 }
694 JoinType::Right | JoinType::RightSemi => {
695 infer_join_predicates_impl::<false, true>(
696 join_col_keys,
697 on_filters,
698 inferred_predicates,
699 )
700 }
701 }
702}
703
704fn infer_join_predicates_impl<
721 const ENABLE_LEFT_TO_RIGHT: bool,
722 const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724 join_col_keys: &[(&Column, &Column)],
725 input_predicates: &[Expr],
726 inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728 for predicate in input_predicates {
729 let mut join_cols_to_replace = HashMap::new();
730
731 for &col in &predicate.column_refs() {
732 for (l, r) in join_col_keys.iter() {
733 if ENABLE_LEFT_TO_RIGHT && col == *l {
734 join_cols_to_replace.insert(col, *r);
735 break;
736 }
737 if ENABLE_RIGHT_TO_LEFT && col == *r {
738 join_cols_to_replace.insert(col, *l);
739 break;
740 }
741 }
742 }
743 if join_cols_to_replace.is_empty() {
744 continue;
745 }
746
747 inferred_predicates
748 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749 }
750 Ok(())
751}
752
impl OptimizerRule for PushDownFilter {
    fn name(&self) -> &str {
        "push_down_filter"
    }

    // Filters are pushed from the root toward the leaves.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a single plan node.
    ///
    /// * A bare `Join` has its own filter (and inferred predicates) pushed into
    ///   its inputs.
    /// * A `Filter` is pushed below its child where the child's kind allows it.
    ///   Predicates that cannot be pushed remain in a `Filter` above the child.
    /// * Any other node is returned unchanged.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        // Captured before `plan` is destructured; used to rebuild a Union.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over Filter: merge the conjuncts (deduplicated, keeping
            // first-seen order via IndexSet) and retry on the merged filter.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                #[allow(clippy::used_underscore_binding)]
                self.rewrite(new_filter, _config)
            }
            // Repartition, Distinct and Sort: the whole filter is moved below
            // the node unchanged.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            // SubqueryAlias: rename the aliased columns back to the input's
            // (position-matched) columns, then push the filter below the alias.
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            // Projection: see `rewrite_projection`. Predicates that depend on
            // volatile projected expressions stay above the projection.
            LogicalPlan::Projection(projection) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: reattach the projection unchanged.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            // Unnest: predicates referencing an unnested list output column
            // must stay above the Unnest; the rest are pushed below it.
            LogicalPlan::Unnest(mut unnest) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                for predicate in predicates {
                    // Collect every column referenced by the predicate.
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    if unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                        accum.contains(&unnest_list.output_column)
                    }) {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Nothing is pushable: restore the original Filter/Unnest pair.
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot panic: `non_unnest_predicates` is non-empty (checked above).
                    conjunction(non_unnest_predicates).unwrap(),
                    unnest_input,
                )?);

                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            // Union: push a copy of the filter into every input. The Union's
            // output columns are renamed to each input's position-matched
            // columns first.
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            // Aggregate: predicates that reference only group-by outputs are
            // rewritten in terms of the group-by expressions and pushed below.
            // All others stay above the Aggregate.
            LogicalPlan::Aggregate(agg) => {
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
                    .collect::<Result<HashSet<_>>>()?;

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Map each group-by output name to its defining expression so
                // pushed predicates refer to the aggregate's input columns.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Window: a predicate is pushed below only if all its columns are
            // partition keys shared by *every* window expression.
            LogicalPlan::Window(window) => {
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    // Assumes window expressions are always
                                    // (possibly aliased) window functions.
                                    unreachable!()
                                }
                            }
                            _ => {
                                // Same assumption as above.
                                unreachable!()
                            }
                        }
                    })
                    // Intersect the partition keys of all window expressions.
                    .reduce(|a, b| &a & &b)
                    // No window expressions: no partition keys, nothing pushable.
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            // TableScan: offer non-volatile predicates to the source.
            // - `Exact` / `Inexact` ones are added to the scan's filters.
            // - `Inexact` / `Unsupported` ones, plus all volatile predicates,
            //   remain in a Filter above the scan.
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                // The source must answer once per filter it was given.
                if non_volatile_filters.len() != supported_filters.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        supported_filters.len(),
                        non_volatile_filters.len());
                }

                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Existing scan filters first, then new ones, without duplicates.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            // Extension: predicates touching any column the node reports via
            // `prevent_predicate_push_down_columns` stay above it. The rest are
            // pushed into every input of the node.
            LogicalPlan::Extension(extension_plan) => {
                // A leaf extension node has nowhere to push into.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` = push, `false` = keep.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            Ok(false)
                        } else {
                            Ok(true)
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Nothing pushable: leave the plan as it was.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the extension node over the filtered children.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other child: the filter cannot be pushed; reattach it.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1223
1224fn rewrite_projection(
1252 predicates: Vec<Expr>,
1253 mut projection: Projection,
1254) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1255 let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1259 .schema
1260 .iter()
1261 .zip(projection.expr.iter())
1262 .map(|((qualifier, field), expr)| {
1263 let expr = expr.clone().unalias();
1265
1266 (qualified_name(qualifier, field.name()), expr)
1267 })
1268 .partition(|(_, value)| value.is_volatile());
1269
1270 let mut push_predicates = vec![];
1271 let mut keep_predicates = vec![];
1272 for expr in predicates {
1273 if contain(&expr, &volatile_map) {
1274 keep_predicates.push(expr);
1275 } else {
1276 push_predicates.push(expr);
1277 }
1278 }
1279
1280 match conjunction(push_predicates) {
1281 Some(expr) => {
1282 let new_filter = LogicalPlan::Filter(Filter::try_new(
1285 replace_cols_by_name(expr, &non_volatile_map)?,
1286 std::mem::take(&mut projection.input),
1287 )?);
1288
1289 projection.input = Arc::new(new_filter);
1290
1291 Ok((
1292 Transformed::yes(LogicalPlan::Projection(projection)),
1293 conjunction(keep_predicates),
1294 ))
1295 }
1296 None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1297 }
1298}
1299
1300pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1302 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1303}
1304
1305fn insert_below(
1319 plan: LogicalPlan,
1320 new_child: LogicalPlan,
1321) -> Result<Transformed<LogicalPlan>> {
1322 let mut new_child = Some(new_child);
1323 let transformed_plan = plan.map_children(|_child| {
1324 if let Some(new_child) = new_child.take() {
1325 Ok(Transformed::yes(new_child))
1326 } else {
1327 internal_err!("node had more than one input")
1329 }
1330 })?;
1331
1332 if new_child.is_some() {
1334 return internal_err!("node had no inputs");
1335 }
1336
1337 Ok(transformed_plan)
1338}
1339
1340impl PushDownFilter {
1341 #[allow(missing_docs)]
1342 pub fn new() -> Self {
1343 Self {}
1344 }
1345}
1346
1347pub fn replace_cols_by_name(
1349 e: Expr,
1350 replace_map: &HashMap<String, Expr>,
1351) -> Result<Expr> {
1352 e.transform_up(|expr| {
1353 Ok(if let Expr::Column(c) = &expr {
1354 match replace_map.get(&c.flat_name()) {
1355 Some(new_c) => Transformed::yes(new_c.clone()),
1356 None => Transformed::no(expr),
1357 }
1358 } else {
1359 Transformed::no(expr)
1360 })
1361 })
1362 .data()
1363}
1364
1365fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1367 let mut is_contain = false;
1368 e.apply(|expr| {
1369 Ok(if let Expr::Column(c) = &expr {
1370 match check_map.get(&c.flat_name()) {
1371 Some(_) => {
1372 is_contain = true;
1373 TreeNodeRecursion::Stop
1374 }
1375 None => TreeNodeRecursion::Continue,
1376 }
1377 } else {
1378 TreeNodeRecursion::Continue
1379 })
1380 })
1381 .unwrap();
1382 is_contain
1383}
1384
1385#[cfg(test)]
1386mod tests {
1387 use std::any::Any;
1388 use std::cmp::Ordering;
1389 use std::fmt::{Debug, Formatter};
1390
1391 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1392 use async_trait::async_trait;
1393
1394 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1395 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1396 use datafusion_expr::logical_plan::table_scan;
1397 use datafusion_expr::{
1398 col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1399 LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1400 TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1401 WindowFunctionDefinition,
1402 };
1403
1404 use crate::assert_optimized_plan_eq_snapshot;
1405 use crate::optimizer::Optimizer;
1406 use crate::simplify_expressions::SimplifyExpressions;
1407 use crate::test::*;
1408 use crate::OptimizerContext;
1409 use datafusion_expr::test::function_stub::sum;
1410 use insta::assert_snapshot;
1411
1412 use super::*;
1413
1414 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1415
1416 macro_rules! assert_optimized_plan_equal {
1417 (
1418 $plan:expr,
1419 @ $expected:literal $(,)?
1420 ) => {{
1421 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1422 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1423 assert_optimized_plan_eq_snapshot!(
1424 optimizer_ctx,
1425 rules,
1426 $plan,
1427 @ $expected,
1428 )
1429 }};
1430 }
1431
1432 macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1433 (
1434 $plan:expr,
1435 @ $expected:literal $(,)?
1436 ) => {{
1437 let optimizer = Optimizer::with_rules(vec![
1438 Arc::new(SimplifyExpressions::new()),
1439 Arc::new(PushDownFilter::new()),
1440 ]);
1441 let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1442 assert_snapshot!(optimized_plan, @ $expected);
1443 Ok::<(), DataFusionError>(())
1444 }};
1445 }
1446
1447 #[test]
1448 fn filter_before_projection() -> Result<()> {
1449 let table_scan = test_table_scan()?;
1450 let plan = LogicalPlanBuilder::from(table_scan)
1451 .project(vec![col("a"), col("b")])?
1452 .filter(col("a").eq(lit(1i64)))?
1453 .build()?;
1454 assert_optimized_plan_equal!(
1456 plan,
1457 @r"
1458 Projection: test.a, test.b
1459 TableScan: test, full_filters=[test.a = Int64(1)]
1460 "
1461 )
1462 }
1463
1464 #[test]
1465 fn filter_after_limit() -> Result<()> {
1466 let table_scan = test_table_scan()?;
1467 let plan = LogicalPlanBuilder::from(table_scan)
1468 .project(vec![col("a"), col("b")])?
1469 .limit(0, Some(10))?
1470 .filter(col("a").eq(lit(1i64)))?
1471 .build()?;
1472 assert_optimized_plan_equal!(
1474 plan,
1475 @r"
1476 Filter: test.a = Int64(1)
1477 Limit: skip=0, fetch=10
1478 Projection: test.a, test.b
1479 TableScan: test
1480 "
1481 )
1482 }
1483
1484 #[test]
1485 fn filter_no_columns() -> Result<()> {
1486 let table_scan = test_table_scan()?;
1487 let plan = LogicalPlanBuilder::from(table_scan)
1488 .filter(lit(0i64).eq(lit(1i64)))?
1489 .build()?;
1490 assert_optimized_plan_equal!(
1491 plan,
1492 @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1493 )
1494 }
1495
1496 #[test]
1497 fn filter_jump_2_plans() -> Result<()> {
1498 let table_scan = test_table_scan()?;
1499 let plan = LogicalPlanBuilder::from(table_scan)
1500 .project(vec![col("a"), col("b"), col("c")])?
1501 .project(vec![col("c"), col("b")])?
1502 .filter(col("a").eq(lit(1i64)))?
1503 .build()?;
1504 assert_optimized_plan_equal!(
1506 plan,
1507 @r"
1508 Projection: test.c, test.b
1509 Projection: test.a, test.b, test.c
1510 TableScan: test, full_filters=[test.a = Int64(1)]
1511 "
1512 )
1513 }
1514
1515 #[test]
1516 fn filter_move_agg() -> Result<()> {
1517 let table_scan = test_table_scan()?;
1518 let plan = LogicalPlanBuilder::from(table_scan)
1519 .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1520 .filter(col("a").gt(lit(10i64)))?
1521 .build()?;
1522 assert_optimized_plan_equal!(
1524 plan,
1525 @r"
1526 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1527 TableScan: test, full_filters=[test.a > Int64(10)]
1528 "
1529 )
1530 }
1531
1532 #[test]
1533 fn filter_complex_group_by() -> Result<()> {
1534 let table_scan = test_table_scan()?;
1535 let plan = LogicalPlanBuilder::from(table_scan)
1536 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1537 .filter(col("b").gt(lit(10i64)))?
1538 .build()?;
1539 assert_optimized_plan_equal!(
1540 plan,
1541 @r"
1542 Filter: test.b > Int64(10)
1543 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1544 TableScan: test
1545 "
1546 )
1547 }
1548
1549 #[test]
1550 fn push_agg_need_replace_expr() -> Result<()> {
1551 let plan = LogicalPlanBuilder::from(test_table_scan()?)
1552 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1553 .filter(col("test.b + test.a").gt(lit(10i64)))?
1554 .build()?;
1555 assert_optimized_plan_equal!(
1556 plan,
1557 @r"
1558 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1559 TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1560 "
1561 )
1562 }
1563
1564 #[test]
1565 fn filter_keep_agg() -> Result<()> {
1566 let table_scan = test_table_scan()?;
1567 let plan = LogicalPlanBuilder::from(table_scan)
1568 .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1569 .filter(col("b").gt(lit(10i64)))?
1570 .build()?;
1571 assert_optimized_plan_equal!(
1573 plan,
1574 @r"
1575 Filter: b > Int64(10)
1576 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1577 TableScan: test
1578 "
1579 )
1580 }
1581
1582 #[test]
1584 fn filter_move_window() -> Result<()> {
1585 let table_scan = test_table_scan()?;
1586
1587 let window = Expr::from(WindowFunction::new(
1588 WindowFunctionDefinition::WindowUDF(
1589 datafusion_functions_window::rank::rank_udwf(),
1590 ),
1591 vec![],
1592 ))
1593 .partition_by(vec![col("a"), col("b")])
1594 .order_by(vec![col("c").sort(true, true)])
1595 .build()
1596 .unwrap();
1597
1598 let plan = LogicalPlanBuilder::from(table_scan)
1599 .window(vec![window])?
1600 .filter(col("b").gt(lit(10i64)))?
1601 .build()?;
1602
1603 assert_optimized_plan_equal!(
1604 plan,
1605 @r"
1606 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1607 TableScan: test, full_filters=[test.b > Int64(10)]
1608 "
1609 )
1610 }
1611
1612 #[test]
1615 fn filter_move_complex_window() -> Result<()> {
1616 let table_scan = test_table_scan()?;
1617
1618 let window = Expr::from(WindowFunction::new(
1619 WindowFunctionDefinition::WindowUDF(
1620 datafusion_functions_window::rank::rank_udwf(),
1621 ),
1622 vec![],
1623 ))
1624 .partition_by(vec![col("a"), col("b")])
1625 .order_by(vec![col("c").sort(true, true)])
1626 .build()
1627 .unwrap();
1628
1629 let plan = LogicalPlanBuilder::from(table_scan)
1630 .window(vec![window])?
1631 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1632 .build()?;
1633
1634 assert_optimized_plan_equal!(
1635 plan,
1636 @r"
1637 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1638 TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1639 "
1640 )
1641 }
1642
1643 #[test]
1645 fn filter_move_partial_window() -> Result<()> {
1646 let table_scan = test_table_scan()?;
1647
1648 let window = Expr::from(WindowFunction::new(
1649 WindowFunctionDefinition::WindowUDF(
1650 datafusion_functions_window::rank::rank_udwf(),
1651 ),
1652 vec![],
1653 ))
1654 .partition_by(vec![col("a")])
1655 .order_by(vec![col("c").sort(true, true)])
1656 .build()
1657 .unwrap();
1658
1659 let plan = LogicalPlanBuilder::from(table_scan)
1660 .window(vec![window])?
1661 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1662 .build()?;
1663
1664 assert_optimized_plan_equal!(
1665 plan,
1666 @r"
1667 Filter: test.b = Int64(1)
1668 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1669 TableScan: test, full_filters=[test.a > Int64(10)]
1670 "
1671 )
1672 }
1673
1674 #[test]
1677 fn filter_expression_keep_window() -> Result<()> {
1678 let table_scan = test_table_scan()?;
1679
1680 let window = Expr::from(WindowFunction::new(
1681 WindowFunctionDefinition::WindowUDF(
1682 datafusion_functions_window::rank::rank_udwf(),
1683 ),
1684 vec![],
1685 ))
1686 .partition_by(vec![add(col("a"), col("b"))]) .order_by(vec![col("c").sort(true, true)])
1688 .build()
1689 .unwrap();
1690
1691 let plan = LogicalPlanBuilder::from(table_scan)
1692 .window(vec![window])?
1693 .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1696 .build()?;
1697
1698 assert_optimized_plan_equal!(
1699 plan,
1700 @r"
1701 Filter: test.a + test.b > Int64(10)
1702 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1703 TableScan: test
1704 "
1705 )
1706 }
1707
1708 #[test]
1710 fn filter_order_keep_window() -> Result<()> {
1711 let table_scan = test_table_scan()?;
1712
1713 let window = Expr::from(WindowFunction::new(
1714 WindowFunctionDefinition::WindowUDF(
1715 datafusion_functions_window::rank::rank_udwf(),
1716 ),
1717 vec![],
1718 ))
1719 .partition_by(vec![col("a")])
1720 .order_by(vec![col("c").sort(true, true)])
1721 .build()
1722 .unwrap();
1723
1724 let plan = LogicalPlanBuilder::from(table_scan)
1725 .window(vec![window])?
1726 .filter(col("c").gt(lit(10i64)))?
1727 .build()?;
1728
1729 assert_optimized_plan_equal!(
1730 plan,
1731 @r"
1732 Filter: test.c > Int64(10)
1733 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1734 TableScan: test
1735 "
1736 )
1737 }
1738
1739 #[test]
1742 fn filter_multiple_windows_common_partitions() -> Result<()> {
1743 let table_scan = test_table_scan()?;
1744
1745 let window1 = Expr::from(WindowFunction::new(
1746 WindowFunctionDefinition::WindowUDF(
1747 datafusion_functions_window::rank::rank_udwf(),
1748 ),
1749 vec![],
1750 ))
1751 .partition_by(vec![col("a")])
1752 .order_by(vec![col("c").sort(true, true)])
1753 .build()
1754 .unwrap();
1755
1756 let window2 = Expr::from(WindowFunction::new(
1757 WindowFunctionDefinition::WindowUDF(
1758 datafusion_functions_window::rank::rank_udwf(),
1759 ),
1760 vec![],
1761 ))
1762 .partition_by(vec![col("b"), col("a")])
1763 .order_by(vec![col("c").sort(true, true)])
1764 .build()
1765 .unwrap();
1766
1767 let plan = LogicalPlanBuilder::from(table_scan)
1768 .window(vec![window1, window2])?
1769 .filter(col("a").gt(lit(10i64)))? .build()?;
1771
1772 assert_optimized_plan_equal!(
1773 plan,
1774 @r"
1775 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1776 TableScan: test, full_filters=[test.a > Int64(10)]
1777 "
1778 )
1779 }
1780
1781 #[test]
1784 fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1785 let table_scan = test_table_scan()?;
1786
1787 let window1 = Expr::from(WindowFunction::new(
1788 WindowFunctionDefinition::WindowUDF(
1789 datafusion_functions_window::rank::rank_udwf(),
1790 ),
1791 vec![],
1792 ))
1793 .partition_by(vec![col("a")])
1794 .order_by(vec![col("c").sort(true, true)])
1795 .build()
1796 .unwrap();
1797
1798 let window2 = Expr::from(WindowFunction::new(
1799 WindowFunctionDefinition::WindowUDF(
1800 datafusion_functions_window::rank::rank_udwf(),
1801 ),
1802 vec![],
1803 ))
1804 .partition_by(vec![col("b"), col("a")])
1805 .order_by(vec![col("c").sort(true, true)])
1806 .build()
1807 .unwrap();
1808
1809 let plan = LogicalPlanBuilder::from(table_scan)
1810 .window(vec![window1, window2])?
1811 .filter(col("b").gt(lit(10i64)))? .build()?;
1813
1814 assert_optimized_plan_equal!(
1815 plan,
1816 @r"
1817 Filter: test.b > Int64(10)
1818 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1819 TableScan: test
1820 "
1821 )
1822 }
1823
1824 #[test]
1826 fn alias() -> Result<()> {
1827 let table_scan = test_table_scan()?;
1828 let plan = LogicalPlanBuilder::from(table_scan)
1829 .project(vec![col("a").alias("b"), col("c")])?
1830 .filter(col("b").eq(lit(1i64)))?
1831 .build()?;
1832 assert_optimized_plan_equal!(
1834 plan,
1835 @r"
1836 Projection: test.a AS b, test.c
1837 TableScan: test, full_filters=[test.a = Int64(1)]
1838 "
1839 )
1840 }
1841
1842 fn add(left: Expr, right: Expr) -> Expr {
1843 Expr::BinaryExpr(BinaryExpr::new(
1844 Box::new(left),
1845 Operator::Plus,
1846 Box::new(right),
1847 ))
1848 }
1849
1850 fn multiply(left: Expr, right: Expr) -> Expr {
1851 Expr::BinaryExpr(BinaryExpr::new(
1852 Box::new(left),
1853 Operator::Multiply,
1854 Box::new(right),
1855 ))
1856 }
1857
1858 #[test]
1860 fn complex_expression() -> Result<()> {
1861 let table_scan = test_table_scan()?;
1862 let plan = LogicalPlanBuilder::from(table_scan)
1863 .project(vec![
1864 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1865 col("c"),
1866 ])?
1867 .filter(col("b").eq(lit(1i64)))?
1868 .build()?;
1869
1870 assert_snapshot!(plan,
1872 @r"
1873 Filter: b = Int64(1)
1874 Projection: test.a * Int32(2) + test.c AS b, test.c
1875 TableScan: test
1876 ",
1877 );
1878 assert_optimized_plan_equal!(
1880 plan,
1881 @r"
1882 Projection: test.a * Int32(2) + test.c AS b, test.c
1883 TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1884 "
1885 )
1886 }
1887
1888 #[test]
1890 fn complex_plan() -> Result<()> {
1891 let table_scan = test_table_scan()?;
1892 let plan = LogicalPlanBuilder::from(table_scan)
1893 .project(vec![
1894 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1895 col("c"),
1896 ])?
1897 .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
1899 .filter(col("a").eq(lit(1i64)))?
1900 .build()?;
1901
1902 assert_snapshot!(plan,
1904 @r"
1905 Filter: a = Int64(1)
1906 Projection: b * Int32(3) AS a, test.c
1907 Projection: test.a * Int32(2) + test.c AS b, test.c
1908 TableScan: test
1909 ",
1910 );
1911 assert_optimized_plan_equal!(
1913 plan,
1914 @r"
1915 Projection: b * Int32(3) AS a, test.c
1916 Projection: test.a * Int32(2) + test.c AS b, test.c
1917 TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
1918 "
1919 )
1920 }
1921
1922 #[derive(Debug, PartialEq, Eq, Hash)]
1923 struct NoopPlan {
1924 input: Vec<LogicalPlan>,
1925 schema: DFSchemaRef,
1926 }
1927
1928 impl PartialOrd for NoopPlan {
1930 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1931 self.input.partial_cmp(&other.input)
1932 }
1933 }
1934
1935 impl UserDefinedLogicalNodeCore for NoopPlan {
1936 fn name(&self) -> &str {
1937 "NoopPlan"
1938 }
1939
1940 fn inputs(&self) -> Vec<&LogicalPlan> {
1941 self.input.iter().collect()
1942 }
1943
1944 fn schema(&self) -> &DFSchemaRef {
1945 &self.schema
1946 }
1947
1948 fn expressions(&self) -> Vec<Expr> {
1949 self.input
1950 .iter()
1951 .flat_map(|child| child.expressions())
1952 .collect()
1953 }
1954
1955 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
1956 HashSet::from_iter(vec!["c".to_string()])
1957 }
1958
1959 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1960 write!(f, "NoopPlan")
1961 }
1962
1963 fn with_exprs_and_inputs(
1964 &self,
1965 _exprs: Vec<Expr>,
1966 inputs: Vec<LogicalPlan>,
1967 ) -> Result<Self> {
1968 Ok(Self {
1969 input: inputs,
1970 schema: Arc::clone(&self.schema),
1971 })
1972 }
1973
1974 fn supports_limit_pushdown(&self) -> bool {
1975 false }
1977 }
1978
1979 #[test]
1980 fn user_defined_plan() -> Result<()> {
1981 let table_scan = test_table_scan()?;
1982
1983 let custom_plan = LogicalPlan::Extension(Extension {
1984 node: Arc::new(NoopPlan {
1985 input: vec![table_scan.clone()],
1986 schema: Arc::clone(table_scan.schema()),
1987 }),
1988 });
1989 let plan = LogicalPlanBuilder::from(custom_plan)
1990 .filter(col("a").eq(lit(1i64)))?
1991 .build()?;
1992
1993 assert_optimized_plan_equal!(
1995 plan,
1996 @r"
1997 NoopPlan
1998 TableScan: test, full_filters=[test.a = Int64(1)]
1999 "
2000 )?;
2001
2002 let custom_plan = LogicalPlan::Extension(Extension {
2003 node: Arc::new(NoopPlan {
2004 input: vec![table_scan.clone()],
2005 schema: Arc::clone(table_scan.schema()),
2006 }),
2007 });
2008 let plan = LogicalPlanBuilder::from(custom_plan)
2009 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2010 .build()?;
2011
2012 assert_optimized_plan_equal!(
2014 plan,
2015 @r"
2016 Filter: test.c = Int64(2)
2017 NoopPlan
2018 TableScan: test, full_filters=[test.a = Int64(1)]
2019 "
2020 )?;
2021
2022 let custom_plan = LogicalPlan::Extension(Extension {
2023 node: Arc::new(NoopPlan {
2024 input: vec![table_scan.clone(), table_scan.clone()],
2025 schema: Arc::clone(table_scan.schema()),
2026 }),
2027 });
2028 let plan = LogicalPlanBuilder::from(custom_plan)
2029 .filter(col("a").eq(lit(1i64)))?
2030 .build()?;
2031
2032 assert_optimized_plan_equal!(
2034 plan,
2035 @r"
2036 NoopPlan
2037 TableScan: test, full_filters=[test.a = Int64(1)]
2038 TableScan: test, full_filters=[test.a = Int64(1)]
2039 "
2040 )?;
2041
2042 let custom_plan = LogicalPlan::Extension(Extension {
2043 node: Arc::new(NoopPlan {
2044 input: vec![table_scan.clone(), table_scan.clone()],
2045 schema: Arc::clone(table_scan.schema()),
2046 }),
2047 });
2048 let plan = LogicalPlanBuilder::from(custom_plan)
2049 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2050 .build()?;
2051
2052 assert_optimized_plan_equal!(
2054 plan,
2055 @r"
2056 Filter: test.c = Int64(2)
2057 NoopPlan
2058 TableScan: test, full_filters=[test.a = Int64(1)]
2059 TableScan: test, full_filters=[test.a = Int64(1)]
2060 "
2061 )
2062 }
2063
2064 #[test]
2067 fn multi_filter() -> Result<()> {
2068 let table_scan = test_table_scan()?;
2070 let plan = LogicalPlanBuilder::from(table_scan)
2071 .project(vec![col("a").alias("b"), col("c")])?
2072 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2073 .filter(col("b").gt(lit(10i64)))?
2074 .filter(col("sum(test.c)").gt(lit(10i64)))?
2075 .build()?;
2076
2077 assert_snapshot!(plan,
2079 @r"
2080 Filter: sum(test.c) > Int64(10)
2081 Filter: b > Int64(10)
2082 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2083 Projection: test.a AS b, test.c
2084 TableScan: test
2085 ",
2086 );
2087 assert_optimized_plan_equal!(
2089 plan,
2090 @r"
2091 Filter: sum(test.c) > Int64(10)
2092 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2093 Projection: test.a AS b, test.c
2094 TableScan: test, full_filters=[test.a > Int64(10)]
2095 "
2096 )
2097 }
2098
2099 #[test]
2102 fn split_filter() -> Result<()> {
2103 let table_scan = test_table_scan()?;
2105 let plan = LogicalPlanBuilder::from(table_scan)
2106 .project(vec![col("a").alias("b"), col("c")])?
2107 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2108 .filter(and(
2109 col("sum(test.c)").gt(lit(10i64)),
2110 and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2111 ))?
2112 .build()?;
2113
2114 assert_snapshot!(plan,
2116 @r"
2117 Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2118 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2119 Projection: test.a AS b, test.c
2120 TableScan: test
2121 ",
2122 );
2123 assert_optimized_plan_equal!(
2125 plan,
2126 @r"
2127 Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2128 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2129 Projection: test.a AS b, test.c
2130 TableScan: test, full_filters=[test.a > Int64(10)]
2131 "
2132 )
2133 }
2134
2135 #[test]
2137 fn double_limit() -> Result<()> {
2138 let table_scan = test_table_scan()?;
2139 let plan = LogicalPlanBuilder::from(table_scan)
2140 .project(vec![col("a"), col("b")])?
2141 .limit(0, Some(20))?
2142 .limit(0, Some(10))?
2143 .project(vec![col("a"), col("b")])?
2144 .filter(col("a").eq(lit(1i64)))?
2145 .build()?;
2146 assert_optimized_plan_equal!(
2148 plan,
2149 @r"
2150 Projection: test.a, test.b
2151 Filter: test.a = Int64(1)
2152 Limit: skip=0, fetch=10
2153 Limit: skip=0, fetch=20
2154 Projection: test.a, test.b
2155 TableScan: test
2156 "
2157 )
2158 }
2159
2160 #[test]
2161 fn union_all() -> Result<()> {
2162 let table_scan = test_table_scan()?;
2163 let table_scan2 = test_table_scan_with_name("test2")?;
2164 let plan = LogicalPlanBuilder::from(table_scan)
2165 .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2166 .filter(col("a").eq(lit(1i64)))?
2167 .build()?;
2168 assert_optimized_plan_equal!(
2170 plan,
2171 @r"
2172 Union
2173 TableScan: test, full_filters=[test.a = Int64(1)]
2174 TableScan: test2, full_filters=[test2.a = Int64(1)]
2175 "
2176 )
2177 }
2178
2179 #[test]
2180 fn union_all_on_projection() -> Result<()> {
2181 let table_scan = test_table_scan()?;
2182 let table = LogicalPlanBuilder::from(table_scan)
2183 .project(vec![col("a").alias("b")])?
2184 .alias("test2")?;
2185
2186 let plan = table
2187 .clone()
2188 .union(table.build()?)?
2189 .filter(col("b").eq(lit(1i64)))?
2190 .build()?;
2191
2192 assert_optimized_plan_equal!(
2194 plan,
2195 @r"
2196 Union
2197 SubqueryAlias: test2
2198 Projection: test.a AS b
2199 TableScan: test, full_filters=[test.a = Int64(1)]
2200 SubqueryAlias: test2
2201 Projection: test.a AS b
2202 TableScan: test, full_filters=[test.a = Int64(1)]
2203 "
2204 )
2205 }
2206
2207 #[test]
2208 fn test_union_different_schema() -> Result<()> {
2209 let left = LogicalPlanBuilder::from(test_table_scan()?)
2210 .project(vec![col("a"), col("b"), col("c")])?
2211 .build()?;
2212
2213 let schema = Schema::new(vec![
2214 Field::new("d", DataType::UInt32, false),
2215 Field::new("e", DataType::UInt32, false),
2216 Field::new("f", DataType::UInt32, false),
2217 ]);
2218 let right = table_scan(Some("test1"), &schema, None)?
2219 .project(vec![col("d"), col("e"), col("f")])?
2220 .build()?;
2221 let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2222 let plan = LogicalPlanBuilder::from(left)
2223 .cross_join(right)?
2224 .project(vec![col("test.a"), col("test1.d")])?
2225 .filter(filter)?
2226 .build()?;
2227
2228 assert_optimized_plan_equal!(
2229 plan,
2230 @r"
2231 Projection: test.a, test1.d
2232 Cross Join:
2233 Projection: test.a, test.b, test.c
2234 TableScan: test, full_filters=[test.a = Int32(1)]
2235 Projection: test1.d, test1.e, test1.f
2236 TableScan: test1, full_filters=[test1.d > Int32(2)]
2237 "
2238 )
2239 }
2240
2241 #[test]
2242 fn test_project_same_name_different_qualifier() -> Result<()> {
2243 let table_scan = test_table_scan()?;
2244 let left = LogicalPlanBuilder::from(table_scan)
2245 .project(vec![col("a"), col("b"), col("c")])?
2246 .build()?;
2247 let right_table_scan = test_table_scan_with_name("test1")?;
2248 let right = LogicalPlanBuilder::from(right_table_scan)
2249 .project(vec![col("a"), col("b"), col("c")])?
2250 .build()?;
2251 let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2252 let plan = LogicalPlanBuilder::from(left)
2253 .cross_join(right)?
2254 .project(vec![col("test.a"), col("test1.a")])?
2255 .filter(filter)?
2256 .build()?;
2257
2258 assert_optimized_plan_equal!(
2259 plan,
2260 @r"
2261 Projection: test.a, test1.a
2262 Cross Join:
2263 Projection: test.a, test.b, test.c
2264 TableScan: test, full_filters=[test.a = Int32(1)]
2265 Projection: test1.a, test1.b, test1.c
2266 TableScan: test1, full_filters=[test1.a > Int32(2)]
2267 "
2268 )
2269 }
2270
2271 #[test]
2273 fn filter_2_breaks_limits() -> Result<()> {
2274 let table_scan = test_table_scan()?;
2275 let plan = LogicalPlanBuilder::from(table_scan)
2276 .project(vec![col("a")])?
2277 .filter(col("a").lt_eq(lit(1i64)))?
2278 .limit(0, Some(1))?
2279 .project(vec![col("a")])?
2280 .filter(col("a").gt_eq(lit(1i64)))?
2281 .build()?;
2282 assert_snapshot!(plan,
2286 @r"
2287 Filter: test.a >= Int64(1)
2288 Projection: test.a
2289 Limit: skip=0, fetch=1
2290 Filter: test.a <= Int64(1)
2291 Projection: test.a
2292 TableScan: test
2293 ",
2294 );
2295 assert_optimized_plan_equal!(
2296 plan,
2297 @r"
2298 Projection: test.a
2299 Filter: test.a >= Int64(1)
2300 Limit: skip=0, fetch=1
2301 Projection: test.a
2302 TableScan: test, full_filters=[test.a <= Int64(1)]
2303 "
2304 )
2305 }
2306
2307 #[test]
2309 fn two_filters_on_same_depth() -> Result<()> {
2310 let table_scan = test_table_scan()?;
2311 let plan = LogicalPlanBuilder::from(table_scan)
2312 .limit(0, Some(1))?
2313 .filter(col("a").lt_eq(lit(1i64)))?
2314 .filter(col("a").gt_eq(lit(1i64)))?
2315 .project(vec![col("a")])?
2316 .build()?;
2317
2318 assert_snapshot!(plan,
2320 @r"
2321 Projection: test.a
2322 Filter: test.a >= Int64(1)
2323 Filter: test.a <= Int64(1)
2324 Limit: skip=0, fetch=1
2325 TableScan: test
2326 ",
2327 );
2328 assert_optimized_plan_equal!(
2329 plan,
2330 @r"
2331 Projection: test.a
2332 Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2333 Limit: skip=0, fetch=1
2334 TableScan: test
2335 "
2336 )
2337 }
2338
2339 #[test]
2342 fn filters_user_defined_node() -> Result<()> {
2343 let table_scan = test_table_scan()?;
2344 let plan = LogicalPlanBuilder::from(table_scan)
2345 .filter(col("a").lt_eq(lit(1i64)))?
2346 .build()?;
2347
2348 let plan = user_defined::new(plan);
2349
2350 assert_snapshot!(plan,
2352 @r"
2353 TestUserDefined
2354 Filter: test.a <= Int64(1)
2355 TableScan: test
2356 ",
2357 );
2358 assert_optimized_plan_equal!(
2359 plan,
2360 @r"
2361 TestUserDefined
2362 TableScan: test, full_filters=[test.a <= Int64(1)]
2363 "
2364 )
2365 }
2366
2367 #[test]
2369 fn filter_on_join_on_common_independent() -> Result<()> {
2370 let table_scan = test_table_scan()?;
2371 let left = LogicalPlanBuilder::from(table_scan).build()?;
2372 let right_table_scan = test_table_scan_with_name("test2")?;
2373 let right = LogicalPlanBuilder::from(right_table_scan)
2374 .project(vec![col("a")])?
2375 .build()?;
2376 let plan = LogicalPlanBuilder::from(left)
2377 .join(
2378 right,
2379 JoinType::Inner,
2380 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2381 None,
2382 )?
2383 .filter(col("test.a").lt_eq(lit(1i64)))?
2384 .build()?;
2385
2386 assert_snapshot!(plan,
2388 @r"
2389 Filter: test.a <= Int64(1)
2390 Inner Join: test.a = test2.a
2391 TableScan: test
2392 Projection: test2.a
2393 TableScan: test2
2394 ",
2395 );
2396 assert_optimized_plan_equal!(
2398 plan,
2399 @r"
2400 Inner Join: test.a = test2.a
2401 TableScan: test, full_filters=[test.a <= Int64(1)]
2402 Projection: test2.a
2403 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2404 "
2405 )
2406 }
2407
2408 #[test]
2410 fn filter_using_join_on_common_independent() -> Result<()> {
2411 let table_scan = test_table_scan()?;
2412 let left = LogicalPlanBuilder::from(table_scan).build()?;
2413 let right_table_scan = test_table_scan_with_name("test2")?;
2414 let right = LogicalPlanBuilder::from(right_table_scan)
2415 .project(vec![col("a")])?
2416 .build()?;
2417 let plan = LogicalPlanBuilder::from(left)
2418 .join_using(
2419 right,
2420 JoinType::Inner,
2421 vec![Column::from_name("a".to_string())],
2422 )?
2423 .filter(col("a").lt_eq(lit(1i64)))?
2424 .build()?;
2425
2426 assert_snapshot!(plan,
2428 @r"
2429 Filter: test.a <= Int64(1)
2430 Inner Join: Using test.a = test2.a
2431 TableScan: test
2432 Projection: test2.a
2433 TableScan: test2
2434 ",
2435 );
2436 assert_optimized_plan_equal!(
2438 plan,
2439 @r"
2440 Inner Join: Using test.a = test2.a
2441 TableScan: test, full_filters=[test.a <= Int64(1)]
2442 Projection: test2.a
2443 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2444 "
2445 )
2446 }
2447
2448 #[test]
2450 fn filter_join_on_common_dependent() -> Result<()> {
2451 let table_scan = test_table_scan()?;
2452 let left = LogicalPlanBuilder::from(table_scan)
2453 .project(vec![col("a"), col("c")])?
2454 .build()?;
2455 let right_table_scan = test_table_scan_with_name("test2")?;
2456 let right = LogicalPlanBuilder::from(right_table_scan)
2457 .project(vec![col("a"), col("b")])?
2458 .build()?;
2459 let plan = LogicalPlanBuilder::from(left)
2460 .join(
2461 right,
2462 JoinType::Inner,
2463 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2464 None,
2465 )?
2466 .filter(col("c").lt_eq(col("b")))?
2467 .build()?;
2468
2469 assert_snapshot!(plan,
2471 @r"
2472 Filter: test.c <= test2.b
2473 Inner Join: test.a = test2.a
2474 Projection: test.a, test.c
2475 TableScan: test
2476 Projection: test2.a, test2.b
2477 TableScan: test2
2478 ",
2479 );
2480 assert_optimized_plan_equal!(
2482 plan,
2483 @r"
2484 Inner Join: test.a = test2.a Filter: test.c <= test2.b
2485 Projection: test.a, test.c
2486 TableScan: test
2487 Projection: test2.a, test2.b
2488 TableScan: test2
2489 "
2490 )
2491 }
2492
2493 #[test]
2495 fn filter_join_on_one_side() -> Result<()> {
2496 let table_scan = test_table_scan()?;
2497 let left = LogicalPlanBuilder::from(table_scan)
2498 .project(vec![col("a"), col("b")])?
2499 .build()?;
2500 let table_scan_right = test_table_scan_with_name("test2")?;
2501 let right = LogicalPlanBuilder::from(table_scan_right)
2502 .project(vec![col("a"), col("c")])?
2503 .build()?;
2504
2505 let plan = LogicalPlanBuilder::from(left)
2506 .join(
2507 right,
2508 JoinType::Inner,
2509 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2510 None,
2511 )?
2512 .filter(col("b").lt_eq(lit(1i64)))?
2513 .build()?;
2514
2515 assert_snapshot!(plan,
2517 @r"
2518 Filter: test.b <= Int64(1)
2519 Inner Join: test.a = test2.a
2520 Projection: test.a, test.b
2521 TableScan: test
2522 Projection: test2.a, test2.c
2523 TableScan: test2
2524 ",
2525 );
2526 assert_optimized_plan_equal!(
2527 plan,
2528 @r"
2529 Inner Join: test.a = test2.a
2530 Projection: test.a, test.b
2531 TableScan: test, full_filters=[test.b <= Int64(1)]
2532 Projection: test2.a, test2.c
2533 TableScan: test2
2534 "
2535 )
2536 }
2537
2538 #[test]
2541 fn filter_using_left_join() -> Result<()> {
2542 let table_scan = test_table_scan()?;
2543 let left = LogicalPlanBuilder::from(table_scan).build()?;
2544 let right_table_scan = test_table_scan_with_name("test2")?;
2545 let right = LogicalPlanBuilder::from(right_table_scan)
2546 .project(vec![col("a")])?
2547 .build()?;
2548 let plan = LogicalPlanBuilder::from(left)
2549 .join_using(
2550 right,
2551 JoinType::Left,
2552 vec![Column::from_name("a".to_string())],
2553 )?
2554 .filter(col("test2.a").lt_eq(lit(1i64)))?
2555 .build()?;
2556
2557 assert_snapshot!(plan,
2559 @r"
2560 Filter: test2.a <= Int64(1)
2561 Left Join: Using test.a = test2.a
2562 TableScan: test
2563 Projection: test2.a
2564 TableScan: test2
2565 ",
2566 );
2567 assert_optimized_plan_equal!(
2569 plan,
2570 @r"
2571 Filter: test2.a <= Int64(1)
2572 Left Join: Using test.a = test2.a
2573 TableScan: test, full_filters=[test.a <= Int64(1)]
2574 Projection: test2.a
2575 TableScan: test2
2576 "
2577 )
2578 }
2579
2580 #[test]
2582 fn filter_using_right_join() -> Result<()> {
2583 let table_scan = test_table_scan()?;
2584 let left = LogicalPlanBuilder::from(table_scan).build()?;
2585 let right_table_scan = test_table_scan_with_name("test2")?;
2586 let right = LogicalPlanBuilder::from(right_table_scan)
2587 .project(vec![col("a")])?
2588 .build()?;
2589 let plan = LogicalPlanBuilder::from(left)
2590 .join_using(
2591 right,
2592 JoinType::Right,
2593 vec![Column::from_name("a".to_string())],
2594 )?
2595 .filter(col("test.a").lt_eq(lit(1i64)))?
2596 .build()?;
2597
2598 assert_snapshot!(plan,
2600 @r"
2601 Filter: test.a <= Int64(1)
2602 Right Join: Using test.a = test2.a
2603 TableScan: test
2604 Projection: test2.a
2605 TableScan: test2
2606 ",
2607 );
2608 assert_optimized_plan_equal!(
2610 plan,
2611 @r"
2612 Filter: test.a <= Int64(1)
2613 Right Join: Using test.a = test2.a
2614 TableScan: test
2615 Projection: test2.a
2616 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2617 "
2618 )
2619 }
2620
2621 #[test]
2624 fn filter_using_left_join_on_common() -> Result<()> {
2625 let table_scan = test_table_scan()?;
2626 let left = LogicalPlanBuilder::from(table_scan).build()?;
2627 let right_table_scan = test_table_scan_with_name("test2")?;
2628 let right = LogicalPlanBuilder::from(right_table_scan)
2629 .project(vec![col("a")])?
2630 .build()?;
2631 let plan = LogicalPlanBuilder::from(left)
2632 .join_using(
2633 right,
2634 JoinType::Left,
2635 vec![Column::from_name("a".to_string())],
2636 )?
2637 .filter(col("a").lt_eq(lit(1i64)))?
2638 .build()?;
2639
2640 assert_snapshot!(plan,
2642 @r"
2643 Filter: test.a <= Int64(1)
2644 Left Join: Using test.a = test2.a
2645 TableScan: test
2646 Projection: test2.a
2647 TableScan: test2
2648 ",
2649 );
2650 assert_optimized_plan_equal!(
2652 plan,
2653 @r"
2654 Left Join: Using test.a = test2.a
2655 TableScan: test, full_filters=[test.a <= Int64(1)]
2656 Projection: test2.a
2657 TableScan: test2
2658 "
2659 )
2660 }
2661
2662 #[test]
2665 fn filter_using_right_join_on_common() -> Result<()> {
2666 let table_scan = test_table_scan()?;
2667 let left = LogicalPlanBuilder::from(table_scan).build()?;
2668 let right_table_scan = test_table_scan_with_name("test2")?;
2669 let right = LogicalPlanBuilder::from(right_table_scan)
2670 .project(vec![col("a")])?
2671 .build()?;
2672 let plan = LogicalPlanBuilder::from(left)
2673 .join_using(
2674 right,
2675 JoinType::Right,
2676 vec![Column::from_name("a".to_string())],
2677 )?
2678 .filter(col("test2.a").lt_eq(lit(1i64)))?
2679 .build()?;
2680
2681 assert_snapshot!(plan,
2683 @r"
2684 Filter: test2.a <= Int64(1)
2685 Right Join: Using test.a = test2.a
2686 TableScan: test
2687 Projection: test2.a
2688 TableScan: test2
2689 ",
2690 );
2691 assert_optimized_plan_equal!(
2693 plan,
2694 @r"
2695 Right Join: Using test.a = test2.a
2696 TableScan: test
2697 Projection: test2.a
2698 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2699 "
2700 )
2701 }
2702
2703 #[test]
2705 fn join_on_with_filter() -> Result<()> {
2706 let table_scan = test_table_scan()?;
2707 let left = LogicalPlanBuilder::from(table_scan)
2708 .project(vec![col("a"), col("b"), col("c")])?
2709 .build()?;
2710 let right_table_scan = test_table_scan_with_name("test2")?;
2711 let right = LogicalPlanBuilder::from(right_table_scan)
2712 .project(vec![col("a"), col("b"), col("c")])?
2713 .build()?;
2714 let filter = col("test.c")
2715 .gt(lit(1u32))
2716 .and(col("test.b").lt(col("test2.b")))
2717 .and(col("test2.c").gt(lit(4u32)));
2718 let plan = LogicalPlanBuilder::from(left)
2719 .join(
2720 right,
2721 JoinType::Inner,
2722 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2723 Some(filter),
2724 )?
2725 .build()?;
2726
2727 assert_snapshot!(plan,
2729 @r"
2730 Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2731 Projection: test.a, test.b, test.c
2732 TableScan: test
2733 Projection: test2.a, test2.b, test2.c
2734 TableScan: test2
2735 ",
2736 );
2737 assert_optimized_plan_equal!(
2738 plan,
2739 @r"
2740 Inner Join: test.a = test2.a Filter: test.b < test2.b
2741 Projection: test.a, test.b, test.c
2742 TableScan: test, full_filters=[test.c > UInt32(1)]
2743 Projection: test2.a, test2.b, test2.c
2744 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2745 "
2746 )
2747 }
2748
2749 #[test]
2751 fn join_filter_removed() -> Result<()> {
2752 let table_scan = test_table_scan()?;
2753 let left = LogicalPlanBuilder::from(table_scan)
2754 .project(vec![col("a"), col("b"), col("c")])?
2755 .build()?;
2756 let right_table_scan = test_table_scan_with_name("test2")?;
2757 let right = LogicalPlanBuilder::from(right_table_scan)
2758 .project(vec![col("a"), col("b"), col("c")])?
2759 .build()?;
2760 let filter = col("test.b")
2761 .gt(lit(1u32))
2762 .and(col("test2.c").gt(lit(4u32)));
2763 let plan = LogicalPlanBuilder::from(left)
2764 .join(
2765 right,
2766 JoinType::Inner,
2767 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2768 Some(filter),
2769 )?
2770 .build()?;
2771
2772 assert_snapshot!(plan,
2774 @r"
2775 Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2776 Projection: test.a, test.b, test.c
2777 TableScan: test
2778 Projection: test2.a, test2.b, test2.c
2779 TableScan: test2
2780 ",
2781 );
2782 assert_optimized_plan_equal!(
2783 plan,
2784 @r"
2785 Inner Join: test.a = test2.a
2786 Projection: test.a, test.b, test.c
2787 TableScan: test, full_filters=[test.b > UInt32(1)]
2788 Projection: test2.a, test2.b, test2.c
2789 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2790 "
2791 )
2792 }
2793
2794 #[test]
2796 fn join_filter_on_common() -> Result<()> {
2797 let table_scan = test_table_scan()?;
2798 let left = LogicalPlanBuilder::from(table_scan)
2799 .project(vec![col("a")])?
2800 .build()?;
2801 let right_table_scan = test_table_scan_with_name("test2")?;
2802 let right = LogicalPlanBuilder::from(right_table_scan)
2803 .project(vec![col("b")])?
2804 .build()?;
2805 let filter = col("test.a").gt(lit(1u32));
2806 let plan = LogicalPlanBuilder::from(left)
2807 .join(
2808 right,
2809 JoinType::Inner,
2810 (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2811 Some(filter),
2812 )?
2813 .build()?;
2814
2815 assert_snapshot!(plan,
2817 @r"
2818 Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2819 Projection: test.a
2820 TableScan: test
2821 Projection: test2.b
2822 TableScan: test2
2823 ",
2824 );
2825 assert_optimized_plan_equal!(
2826 plan,
2827 @r"
2828 Inner Join: test.a = test2.b
2829 Projection: test.a
2830 TableScan: test, full_filters=[test.a > UInt32(1)]
2831 Projection: test2.b
2832 TableScan: test2, full_filters=[test2.b > UInt32(1)]
2833 "
2834 )
2835 }
2836
2837 #[test]
2839 fn left_join_on_with_filter() -> Result<()> {
2840 let table_scan = test_table_scan()?;
2841 let left = LogicalPlanBuilder::from(table_scan)
2842 .project(vec![col("a"), col("b"), col("c")])?
2843 .build()?;
2844 let right_table_scan = test_table_scan_with_name("test2")?;
2845 let right = LogicalPlanBuilder::from(right_table_scan)
2846 .project(vec![col("a"), col("b"), col("c")])?
2847 .build()?;
2848 let filter = col("test.a")
2849 .gt(lit(1u32))
2850 .and(col("test.b").lt(col("test2.b")))
2851 .and(col("test2.c").gt(lit(4u32)));
2852 let plan = LogicalPlanBuilder::from(left)
2853 .join(
2854 right,
2855 JoinType::Left,
2856 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2857 Some(filter),
2858 )?
2859 .build()?;
2860
2861 assert_snapshot!(plan,
2863 @r"
2864 Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2865 Projection: test.a, test.b, test.c
2866 TableScan: test
2867 Projection: test2.a, test2.b, test2.c
2868 TableScan: test2
2869 ",
2870 );
2871 assert_optimized_plan_equal!(
2872 plan,
2873 @r"
2874 Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2875 Projection: test.a, test.b, test.c
2876 TableScan: test
2877 Projection: test2.a, test2.b, test2.c
2878 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2879 "
2880 )
2881 }
2882
2883 #[test]
2885 fn right_join_on_with_filter() -> Result<()> {
2886 let table_scan = test_table_scan()?;
2887 let left = LogicalPlanBuilder::from(table_scan)
2888 .project(vec![col("a"), col("b"), col("c")])?
2889 .build()?;
2890 let right_table_scan = test_table_scan_with_name("test2")?;
2891 let right = LogicalPlanBuilder::from(right_table_scan)
2892 .project(vec![col("a"), col("b"), col("c")])?
2893 .build()?;
2894 let filter = col("test.a")
2895 .gt(lit(1u32))
2896 .and(col("test.b").lt(col("test2.b")))
2897 .and(col("test2.c").gt(lit(4u32)));
2898 let plan = LogicalPlanBuilder::from(left)
2899 .join(
2900 right,
2901 JoinType::Right,
2902 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2903 Some(filter),
2904 )?
2905 .build()?;
2906
2907 assert_snapshot!(plan,
2909 @r"
2910 Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2911 Projection: test.a, test.b, test.c
2912 TableScan: test
2913 Projection: test2.a, test2.b, test2.c
2914 TableScan: test2
2915 ",
2916 );
2917 assert_optimized_plan_equal!(
2918 plan,
2919 @r"
2920 Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
2921 Projection: test.a, test.b, test.c
2922 TableScan: test, full_filters=[test.a > UInt32(1)]
2923 Projection: test2.a, test2.b, test2.c
2924 TableScan: test2
2925 "
2926 )
2927 }
2928
2929 #[test]
2931 fn full_join_on_with_filter() -> Result<()> {
2932 let table_scan = test_table_scan()?;
2933 let left = LogicalPlanBuilder::from(table_scan)
2934 .project(vec![col("a"), col("b"), col("c")])?
2935 .build()?;
2936 let right_table_scan = test_table_scan_with_name("test2")?;
2937 let right = LogicalPlanBuilder::from(right_table_scan)
2938 .project(vec![col("a"), col("b"), col("c")])?
2939 .build()?;
2940 let filter = col("test.a")
2941 .gt(lit(1u32))
2942 .and(col("test.b").lt(col("test2.b")))
2943 .and(col("test2.c").gt(lit(4u32)));
2944 let plan = LogicalPlanBuilder::from(left)
2945 .join(
2946 right,
2947 JoinType::Full,
2948 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2949 Some(filter),
2950 )?
2951 .build()?;
2952
2953 assert_snapshot!(plan,
2955 @r"
2956 Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2957 Projection: test.a, test.b, test.c
2958 TableScan: test
2959 Projection: test2.a, test2.b, test2.c
2960 TableScan: test2
2961 ",
2962 );
2963 assert_optimized_plan_equal!(
2964 plan,
2965 @r"
2966 Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2967 Projection: test.a, test.b, test.c
2968 TableScan: test
2969 Projection: test2.a, test2.b, test2.c
2970 TableScan: test2
2971 "
2972 )
2973 }
2974
2975 struct PushDownProvider {
2976 pub filter_support: TableProviderFilterPushDown,
2977 }
2978
2979 #[async_trait]
2980 impl TableSource for PushDownProvider {
2981 fn schema(&self) -> SchemaRef {
2982 Arc::new(Schema::new(vec![
2983 Field::new("a", DataType::Int32, true),
2984 Field::new("b", DataType::Int32, true),
2985 ]))
2986 }
2987
2988 fn table_type(&self) -> TableType {
2989 TableType::Base
2990 }
2991
2992 fn supports_filters_pushdown(
2993 &self,
2994 filters: &[&Expr],
2995 ) -> Result<Vec<TableProviderFilterPushDown>> {
2996 Ok((0..filters.len())
2997 .map(|_| self.filter_support.clone())
2998 .collect())
2999 }
3000
3001 fn as_any(&self) -> &dyn Any {
3002 self
3003 }
3004 }
3005
3006 fn table_scan_with_pushdown_provider_builder(
3007 filter_support: TableProviderFilterPushDown,
3008 filters: Vec<Expr>,
3009 projection: Option<Vec<usize>>,
3010 ) -> Result<LogicalPlanBuilder> {
3011 let test_provider = PushDownProvider { filter_support };
3012
3013 let table_scan = LogicalPlan::TableScan(TableScan {
3014 table_name: "test".into(),
3015 filters,
3016 projected_schema: Arc::new(DFSchema::try_from(
3017 (*test_provider.schema()).clone(),
3018 )?),
3019 projection,
3020 source: Arc::new(test_provider),
3021 fetch: None,
3022 });
3023
3024 Ok(LogicalPlanBuilder::from(table_scan))
3025 }
3026
3027 fn table_scan_with_pushdown_provider(
3028 filter_support: TableProviderFilterPushDown,
3029 ) -> Result<LogicalPlan> {
3030 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3031 .filter(col("a").eq(lit(1i64)))?
3032 .build()
3033 }
3034
3035 #[test]
3036 fn filter_with_table_provider_exact() -> Result<()> {
3037 let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3038
3039 assert_optimized_plan_equal!(
3040 plan,
3041 @"TableScan: test, full_filters=[a = Int64(1)]"
3042 )
3043 }
3044
3045 #[test]
3046 fn filter_with_table_provider_inexact() -> Result<()> {
3047 let plan =
3048 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3049
3050 assert_optimized_plan_equal!(
3051 plan,
3052 @r"
3053 Filter: a = Int64(1)
3054 TableScan: test, partial_filters=[a = Int64(1)]
3055 "
3056 )
3057 }
3058
3059 #[test]
3060 fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3061 let plan =
3062 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3063
3064 let optimized_plan = PushDownFilter::new()
3065 .rewrite(plan, &OptimizerContext::new())
3066 .expect("failed to optimize plan")
3067 .data;
3068
3069 assert_optimized_plan_equal!(
3072 optimized_plan,
3073 @r"
3074 Filter: a = Int64(1)
3075 TableScan: test, partial_filters=[a = Int64(1)]
3076 "
3077 )
3078 }
3079
3080 #[test]
3081 fn filter_with_table_provider_unsupported() -> Result<()> {
3082 let plan =
3083 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3084
3085 assert_optimized_plan_equal!(
3086 plan,
3087 @r"
3088 Filter: a = Int64(1)
3089 TableScan: test
3090 "
3091 )
3092 }
3093
3094 #[test]
3095 fn multi_combined_filter() -> Result<()> {
3096 let plan = table_scan_with_pushdown_provider_builder(
3097 TableProviderFilterPushDown::Inexact,
3098 vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3099 Some(vec![0]),
3100 )?
3101 .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3102 .project(vec![col("a"), col("b")])?
3103 .build()?;
3104
3105 assert_optimized_plan_equal!(
3106 plan,
3107 @r"
3108 Projection: a, b
3109 Filter: a = Int64(10) AND b > Int64(11)
3110 TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3111 "
3112 )
3113 }
3114
3115 #[test]
3116 fn multi_combined_filter_exact() -> Result<()> {
3117 let plan = table_scan_with_pushdown_provider_builder(
3118 TableProviderFilterPushDown::Exact,
3119 vec![],
3120 Some(vec![0]),
3121 )?
3122 .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3123 .project(vec![col("a"), col("b")])?
3124 .build()?;
3125
3126 assert_optimized_plan_equal!(
3127 plan,
3128 @r"
3129 Projection: a, b
3130 TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3131 "
3132 )
3133 }
3134
3135 #[test]
3136 fn test_filter_with_alias() -> Result<()> {
3137 let table_scan = test_table_scan()?;
3141 let plan = LogicalPlanBuilder::from(table_scan)
3142 .project(vec![col("a").alias("b"), col("c")])?
3143 .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3144 .build()?;
3145
3146 assert_snapshot!(plan,
3148 @r"
3149 Filter: b > Int64(10) AND test.c > Int64(10)
3150 Projection: test.a AS b, test.c
3151 TableScan: test
3152 ",
3153 );
3154 assert_optimized_plan_equal!(
3156 plan,
3157 @r"
3158 Projection: test.a AS b, test.c
3159 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3160 "
3161 )
3162 }
3163
3164 #[test]
3165 fn test_filter_with_alias_2() -> Result<()> {
3166 let table_scan = test_table_scan()?;
3170 let plan = LogicalPlanBuilder::from(table_scan)
3171 .project(vec![col("a").alias("b"), col("c")])?
3172 .project(vec![col("b"), col("c")])?
3173 .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3174 .build()?;
3175
3176 assert_snapshot!(plan,
3178 @r"
3179 Filter: b > Int64(10) AND test.c > Int64(10)
3180 Projection: b, test.c
3181 Projection: test.a AS b, test.c
3182 TableScan: test
3183 ",
3184 );
3185 assert_optimized_plan_equal!(
3187 plan,
3188 @r"
3189 Projection: b, test.c
3190 Projection: test.a AS b, test.c
3191 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3192 "
3193 )
3194 }
3195
3196 #[test]
3197 fn test_filter_with_multi_alias() -> Result<()> {
3198 let table_scan = test_table_scan()?;
3199 let plan = LogicalPlanBuilder::from(table_scan)
3200 .project(vec![col("a").alias("b"), col("c").alias("d")])?
3201 .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3202 .build()?;
3203
3204 assert_snapshot!(plan,
3206 @r"
3207 Filter: b > Int64(10) AND d > Int64(10)
3208 Projection: test.a AS b, test.c AS d
3209 TableScan: test
3210 ",
3211 );
3212 assert_optimized_plan_equal!(
3214 plan,
3215 @r"
3216 Projection: test.a AS b, test.c AS d
3217 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3218 "
3219 )
3220 }
3221
3222 #[test]
3224 fn join_filter_with_alias() -> Result<()> {
3225 let table_scan = test_table_scan()?;
3226 let left = LogicalPlanBuilder::from(table_scan)
3227 .project(vec![col("a").alias("c")])?
3228 .build()?;
3229 let right_table_scan = test_table_scan_with_name("test2")?;
3230 let right = LogicalPlanBuilder::from(right_table_scan)
3231 .project(vec![col("b").alias("d")])?
3232 .build()?;
3233 let filter = col("c").gt(lit(1u32));
3234 let plan = LogicalPlanBuilder::from(left)
3235 .join(
3236 right,
3237 JoinType::Inner,
3238 (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3239 Some(filter),
3240 )?
3241 .build()?;
3242
3243 assert_snapshot!(plan,
3244 @r"
3245 Inner Join: c = d Filter: c > UInt32(1)
3246 Projection: test.a AS c
3247 TableScan: test
3248 Projection: test2.b AS d
3249 TableScan: test2
3250 ",
3251 );
3252 assert_optimized_plan_equal!(
3254 plan,
3255 @r"
3256 Inner Join: c = d
3257 Projection: test.a AS c
3258 TableScan: test, full_filters=[test.a > UInt32(1)]
3259 Projection: test2.b AS d
3260 TableScan: test2, full_filters=[test2.b > UInt32(1)]
3261 "
3262 )
3263 }
3264
3265 #[test]
3266 fn test_in_filter_with_alias() -> Result<()> {
3267 let table_scan = test_table_scan()?;
3271 let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3272 let plan = LogicalPlanBuilder::from(table_scan)
3273 .project(vec![col("a").alias("b"), col("c")])?
3274 .filter(in_list(col("b"), filter_value, false))?
3275 .build()?;
3276
3277 assert_snapshot!(plan,
3279 @r"
3280 Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3281 Projection: test.a AS b, test.c
3282 TableScan: test
3283 ",
3284 );
3285 assert_optimized_plan_equal!(
3287 plan,
3288 @r"
3289 Projection: test.a AS b, test.c
3290 TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3291 "
3292 )
3293 }
3294
3295 #[test]
3296 fn test_in_filter_with_alias_2() -> Result<()> {
3297 let table_scan = test_table_scan()?;
3301 let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3302 let plan = LogicalPlanBuilder::from(table_scan)
3303 .project(vec![col("a").alias("b"), col("c")])?
3304 .project(vec![col("b"), col("c")])?
3305 .filter(in_list(col("b"), filter_value, false))?
3306 .build()?;
3307
3308 assert_snapshot!(plan,
3310 @r"
3311 Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3312 Projection: b, test.c
3313 Projection: test.a AS b, test.c
3314 TableScan: test
3315 ",
3316 );
3317 assert_optimized_plan_equal!(
3319 plan,
3320 @r"
3321 Projection: b, test.c
3322 Projection: test.a AS b, test.c
3323 TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3324 "
3325 )
3326 }
3327
3328 #[test]
3329 fn test_in_subquery_with_alias() -> Result<()> {
3330 let table_scan = test_table_scan()?;
3333 let table_scan_sq = test_table_scan_with_name("sq")?;
3334 let subplan = Arc::new(
3335 LogicalPlanBuilder::from(table_scan_sq)
3336 .project(vec![col("c")])?
3337 .build()?,
3338 );
3339 let plan = LogicalPlanBuilder::from(table_scan)
3340 .project(vec![col("a").alias("b"), col("c")])?
3341 .filter(in_subquery(col("b"), subplan))?
3342 .build()?;
3343
3344 assert_snapshot!(plan,
3346 @r"
3347 Filter: b IN (<subquery>)
3348 Subquery:
3349 Projection: sq.c
3350 TableScan: sq
3351 Projection: test.a AS b, test.c
3352 TableScan: test
3353 ",
3354 );
3355 assert_optimized_plan_equal!(
3357 plan,
3358 @r"
3359 Projection: test.a AS b, test.c
3360 TableScan: test, full_filters=[test.a IN (<subquery>)]
3361 Subquery:
3362 Projection: sq.c
3363 TableScan: sq
3364 "
3365 )
3366 }
3367
3368 #[test]
3369 fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3370 let plan = LogicalPlanBuilder::empty(true)
3372 .project(vec![lit(0i64).alias("a")])?
3373 .alias("b")?
3374 .project(vec![col("b.a")])?
3375 .alias("b")?
3376 .filter(col("b.a").eq(lit(1i64)))?
3377 .project(vec![col("b.a")])?
3378 .build()?;
3379
3380 assert_snapshot!(plan,
3381 @r"
3382 Projection: b.a
3383 Filter: b.a = Int64(1)
3384 SubqueryAlias: b
3385 Projection: b.a
3386 SubqueryAlias: b
3387 Projection: Int64(0) AS a
3388 EmptyRelation
3389 ",
3390 );
3391 assert_optimized_plan_equal!(
3394 plan,
3395 @r"
3396 Projection: b.a
3397 SubqueryAlias: b
3398 Projection: b.a
3399 SubqueryAlias: b
3400 Projection: Int64(0) AS a
3401 Filter: Int64(0) = Int64(1)
3402 EmptyRelation
3403 "
3404 )
3405 }
3406
3407 #[test]
3408 fn test_crossjoin_with_or_clause() -> Result<()> {
3409 let table_scan = test_table_scan()?;
3411 let left = LogicalPlanBuilder::from(table_scan)
3412 .project(vec![col("a"), col("b"), col("c")])?
3413 .build()?;
3414 let right_table_scan = test_table_scan_with_name("test1")?;
3415 let right = LogicalPlanBuilder::from(right_table_scan)
3416 .project(vec![col("a").alias("d"), col("a").alias("e")])?
3417 .build()?;
3418 let filter = or(
3419 and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3420 and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3421 );
3422 let plan = LogicalPlanBuilder::from(left)
3423 .cross_join(right)?
3424 .filter(filter)?
3425 .build()?;
3426
3427 assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3428 Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3429 Projection: test.a, test.b, test.c
3430 TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3431 Projection: test1.a AS d, test1.a AS e
3432 TableScan: test1
3433 ")?;
3434
3435 let optimized_plan = PushDownFilter::new()
3438 .rewrite(plan, &OptimizerContext::new())
3439 .expect("failed to optimize plan")
3440 .data;
3441 assert_optimized_plan_equal!(
3442 optimized_plan,
3443 @r"
3444 Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3445 Projection: test.a, test.b, test.c
3446 TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3447 Projection: test1.a AS d, test1.a AS e
3448 TableScan: test1
3449 "
3450 )
3451 }
3452
3453 #[test]
3454 fn left_semi_join() -> Result<()> {
3455 let left = test_table_scan_with_name("test1")?;
3456 let right_table_scan = test_table_scan_with_name("test2")?;
3457 let right = LogicalPlanBuilder::from(right_table_scan)
3458 .project(vec![col("a"), col("b")])?
3459 .build()?;
3460 let plan = LogicalPlanBuilder::from(left)
3461 .join(
3462 right,
3463 JoinType::LeftSemi,
3464 (
3465 vec![Column::from_qualified_name("test1.a")],
3466 vec![Column::from_qualified_name("test2.a")],
3467 ),
3468 None,
3469 )?
3470 .filter(col("test2.a").lt_eq(lit(1i64)))?
3471 .build()?;
3472
3473 assert_snapshot!(plan,
3475 @r"
3476 Filter: test2.a <= Int64(1)
3477 LeftSemi Join: test1.a = test2.a
3478 TableScan: test1
3479 Projection: test2.a, test2.b
3480 TableScan: test2
3481 ",
3482 );
3483 assert_optimized_plan_equal!(
3485 plan,
3486 @r"
3487 Filter: test2.a <= Int64(1)
3488 LeftSemi Join: test1.a = test2.a
3489 TableScan: test1, full_filters=[test1.a <= Int64(1)]
3490 Projection: test2.a, test2.b
3491 TableScan: test2
3492 "
3493 )
3494 }
3495
3496 #[test]
3497 fn left_semi_join_with_filters() -> Result<()> {
3498 let left = test_table_scan_with_name("test1")?;
3499 let right_table_scan = test_table_scan_with_name("test2")?;
3500 let right = LogicalPlanBuilder::from(right_table_scan)
3501 .project(vec![col("a"), col("b")])?
3502 .build()?;
3503 let plan = LogicalPlanBuilder::from(left)
3504 .join(
3505 right,
3506 JoinType::LeftSemi,
3507 (
3508 vec![Column::from_qualified_name("test1.a")],
3509 vec![Column::from_qualified_name("test2.a")],
3510 ),
3511 Some(
3512 col("test1.b")
3513 .gt(lit(1u32))
3514 .and(col("test2.b").gt(lit(2u32))),
3515 ),
3516 )?
3517 .build()?;
3518
3519 assert_snapshot!(plan,
3521 @r"
3522 LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3523 TableScan: test1
3524 Projection: test2.a, test2.b
3525 TableScan: test2
3526 ",
3527 );
3528 assert_optimized_plan_equal!(
3530 plan,
3531 @r"
3532 LeftSemi Join: test1.a = test2.a
3533 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3534 Projection: test2.a, test2.b
3535 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3536 "
3537 )
3538 }
3539
3540 #[test]
3541 fn right_semi_join() -> Result<()> {
3542 let left = test_table_scan_with_name("test1")?;
3543 let right_table_scan = test_table_scan_with_name("test2")?;
3544 let right = LogicalPlanBuilder::from(right_table_scan)
3545 .project(vec![col("a"), col("b")])?
3546 .build()?;
3547 let plan = LogicalPlanBuilder::from(left)
3548 .join(
3549 right,
3550 JoinType::RightSemi,
3551 (
3552 vec![Column::from_qualified_name("test1.a")],
3553 vec![Column::from_qualified_name("test2.a")],
3554 ),
3555 None,
3556 )?
3557 .filter(col("test1.a").lt_eq(lit(1i64)))?
3558 .build()?;
3559
3560 assert_snapshot!(plan,
3562 @r"
3563 Filter: test1.a <= Int64(1)
3564 RightSemi Join: test1.a = test2.a
3565 TableScan: test1
3566 Projection: test2.a, test2.b
3567 TableScan: test2
3568 ",
3569 );
3570 assert_optimized_plan_equal!(
3572 plan,
3573 @r"
3574 Filter: test1.a <= Int64(1)
3575 RightSemi Join: test1.a = test2.a
3576 TableScan: test1
3577 Projection: test2.a, test2.b
3578 TableScan: test2, full_filters=[test2.a <= Int64(1)]
3579 "
3580 )
3581 }
3582
3583 #[test]
3584 fn right_semi_join_with_filters() -> Result<()> {
3585 let left = test_table_scan_with_name("test1")?;
3586 let right_table_scan = test_table_scan_with_name("test2")?;
3587 let right = LogicalPlanBuilder::from(right_table_scan)
3588 .project(vec![col("a"), col("b")])?
3589 .build()?;
3590 let plan = LogicalPlanBuilder::from(left)
3591 .join(
3592 right,
3593 JoinType::RightSemi,
3594 (
3595 vec![Column::from_qualified_name("test1.a")],
3596 vec![Column::from_qualified_name("test2.a")],
3597 ),
3598 Some(
3599 col("test1.b")
3600 .gt(lit(1u32))
3601 .and(col("test2.b").gt(lit(2u32))),
3602 ),
3603 )?
3604 .build()?;
3605
3606 assert_snapshot!(plan,
3608 @r"
3609 RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3610 TableScan: test1
3611 Projection: test2.a, test2.b
3612 TableScan: test2
3613 ",
3614 );
3615 assert_optimized_plan_equal!(
3617 plan,
3618 @r"
3619 RightSemi Join: test1.a = test2.a
3620 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3621 Projection: test2.a, test2.b
3622 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3623 "
3624 )
3625 }
3626
3627 #[test]
3628 fn left_anti_join() -> Result<()> {
3629 let table_scan = test_table_scan_with_name("test1")?;
3630 let left = LogicalPlanBuilder::from(table_scan)
3631 .project(vec![col("a"), col("b")])?
3632 .build()?;
3633 let right_table_scan = test_table_scan_with_name("test2")?;
3634 let right = LogicalPlanBuilder::from(right_table_scan)
3635 .project(vec![col("a"), col("b")])?
3636 .build()?;
3637 let plan = LogicalPlanBuilder::from(left)
3638 .join(
3639 right,
3640 JoinType::LeftAnti,
3641 (
3642 vec![Column::from_qualified_name("test1.a")],
3643 vec![Column::from_qualified_name("test2.a")],
3644 ),
3645 None,
3646 )?
3647 .filter(col("test2.a").gt(lit(2u32)))?
3648 .build()?;
3649
3650 assert_snapshot!(plan,
3652 @r"
3653 Filter: test2.a > UInt32(2)
3654 LeftAnti Join: test1.a = test2.a
3655 Projection: test1.a, test1.b
3656 TableScan: test1
3657 Projection: test2.a, test2.b
3658 TableScan: test2
3659 ",
3660 );
3661 assert_optimized_plan_equal!(
3663 plan,
3664 @r"
3665 Filter: test2.a > UInt32(2)
3666 LeftAnti Join: test1.a = test2.a
3667 Projection: test1.a, test1.b
3668 TableScan: test1, full_filters=[test1.a > UInt32(2)]
3669 Projection: test2.a, test2.b
3670 TableScan: test2
3671 "
3672 )
3673 }
3674
3675 #[test]
3676 fn left_anti_join_with_filters() -> Result<()> {
3677 let table_scan = test_table_scan_with_name("test1")?;
3678 let left = LogicalPlanBuilder::from(table_scan)
3679 .project(vec![col("a"), col("b")])?
3680 .build()?;
3681 let right_table_scan = test_table_scan_with_name("test2")?;
3682 let right = LogicalPlanBuilder::from(right_table_scan)
3683 .project(vec![col("a"), col("b")])?
3684 .build()?;
3685 let plan = LogicalPlanBuilder::from(left)
3686 .join(
3687 right,
3688 JoinType::LeftAnti,
3689 (
3690 vec![Column::from_qualified_name("test1.a")],
3691 vec![Column::from_qualified_name("test2.a")],
3692 ),
3693 Some(
3694 col("test1.b")
3695 .gt(lit(1u32))
3696 .and(col("test2.b").gt(lit(2u32))),
3697 ),
3698 )?
3699 .build()?;
3700
3701 assert_snapshot!(plan,
3703 @r"
3704 LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3705 Projection: test1.a, test1.b
3706 TableScan: test1
3707 Projection: test2.a, test2.b
3708 TableScan: test2
3709 ",
3710 );
3711 assert_optimized_plan_equal!(
3713 plan,
3714 @r"
3715 LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3716 Projection: test1.a, test1.b
3717 TableScan: test1
3718 Projection: test2.a, test2.b
3719 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3720 "
3721 )
3722 }
3723
3724 #[test]
3725 fn right_anti_join() -> Result<()> {
3726 let table_scan = test_table_scan_with_name("test1")?;
3727 let left = LogicalPlanBuilder::from(table_scan)
3728 .project(vec![col("a"), col("b")])?
3729 .build()?;
3730 let right_table_scan = test_table_scan_with_name("test2")?;
3731 let right = LogicalPlanBuilder::from(right_table_scan)
3732 .project(vec![col("a"), col("b")])?
3733 .build()?;
3734 let plan = LogicalPlanBuilder::from(left)
3735 .join(
3736 right,
3737 JoinType::RightAnti,
3738 (
3739 vec![Column::from_qualified_name("test1.a")],
3740 vec![Column::from_qualified_name("test2.a")],
3741 ),
3742 None,
3743 )?
3744 .filter(col("test1.a").gt(lit(2u32)))?
3745 .build()?;
3746
3747 assert_snapshot!(plan,
3749 @r"
3750 Filter: test1.a > UInt32(2)
3751 RightAnti Join: test1.a = test2.a
3752 Projection: test1.a, test1.b
3753 TableScan: test1
3754 Projection: test2.a, test2.b
3755 TableScan: test2
3756 ",
3757 );
3758 assert_optimized_plan_equal!(
3760 plan,
3761 @r"
3762 Filter: test1.a > UInt32(2)
3763 RightAnti Join: test1.a = test2.a
3764 Projection: test1.a, test1.b
3765 TableScan: test1
3766 Projection: test2.a, test2.b
3767 TableScan: test2, full_filters=[test2.a > UInt32(2)]
3768 "
3769 )
3770 }
3771
3772 #[test]
3773 fn right_anti_join_with_filters() -> Result<()> {
3774 let table_scan = test_table_scan_with_name("test1")?;
3775 let left = LogicalPlanBuilder::from(table_scan)
3776 .project(vec![col("a"), col("b")])?
3777 .build()?;
3778 let right_table_scan = test_table_scan_with_name("test2")?;
3779 let right = LogicalPlanBuilder::from(right_table_scan)
3780 .project(vec![col("a"), col("b")])?
3781 .build()?;
3782 let plan = LogicalPlanBuilder::from(left)
3783 .join(
3784 right,
3785 JoinType::RightAnti,
3786 (
3787 vec![Column::from_qualified_name("test1.a")],
3788 vec![Column::from_qualified_name("test2.a")],
3789 ),
3790 Some(
3791 col("test1.b")
3792 .gt(lit(1u32))
3793 .and(col("test2.b").gt(lit(2u32))),
3794 ),
3795 )?
3796 .build()?;
3797
3798 assert_snapshot!(plan,
3800 @r"
3801 RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3802 Projection: test1.a, test1.b
3803 TableScan: test1
3804 Projection: test2.a, test2.b
3805 TableScan: test2
3806 ",
3807 );
3808 assert_optimized_plan_equal!(
3810 plan,
3811 @r"
3812 RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3813 Projection: test1.a, test1.b
3814 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3815 Projection: test2.a, test2.b
3816 TableScan: test2
3817 "
3818 )
3819 }
3820
3821 #[derive(Debug)]
3822 struct TestScalarUDF {
3823 signature: Signature,
3824 }
3825
3826 impl ScalarUDFImpl for TestScalarUDF {
3827 fn as_any(&self) -> &dyn Any {
3828 self
3829 }
3830 fn name(&self) -> &str {
3831 "TestScalarUDF"
3832 }
3833
3834 fn signature(&self) -> &Signature {
3835 &self.signature
3836 }
3837
3838 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3839 Ok(DataType::Int32)
3840 }
3841
3842 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3843 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3844 }
3845 }
3846
3847 #[test]
3848 fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3849 let table_scan = test_table_scan_with_name("test1")?;
3851 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3852 signature: Signature::exact(vec![], Volatility::Volatile),
3853 });
3854 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3855
3856 let plan = LogicalPlanBuilder::from(table_scan)
3857 .aggregate(vec![col("a")], vec![sum(col("b"))])?
3858 .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3859 .alias("t")?
3860 .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3861 .project(vec![col("t.a"), col("t.r")])?
3862 .build()?;
3863
3864 assert_snapshot!(plan,
3865 @r"
3866 Projection: t.a, t.r
3867 Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3868 SubqueryAlias: t
3869 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3870 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3871 TableScan: test1
3872 ",
3873 );
3874 assert_optimized_plan_equal!(
3875 plan,
3876 @r"
3877 Projection: t.a, t.r
3878 SubqueryAlias: t
3879 Filter: r > Float64(0.5)
3880 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3881 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3882 TableScan: test1, full_filters=[test1.a > Int32(5)]
3883 "
3884 )
3885 }
3886
3887 #[test]
3888 fn test_push_down_volatile_function_in_join() -> Result<()> {
3889 let table_scan = test_table_scan_with_name("test1")?;
3891 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3892 signature: Signature::exact(vec![], Volatility::Volatile),
3893 });
3894 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3895 let left = LogicalPlanBuilder::from(table_scan).build()?;
3896 let right_table_scan = test_table_scan_with_name("test2")?;
3897 let right = LogicalPlanBuilder::from(right_table_scan).build()?;
3898 let plan = LogicalPlanBuilder::from(left)
3899 .join(
3900 right,
3901 JoinType::Inner,
3902 (
3903 vec![Column::from_qualified_name("test1.a")],
3904 vec![Column::from_qualified_name("test2.a")],
3905 ),
3906 None,
3907 )?
3908 .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
3909 .alias("t")?
3910 .filter(col("t.r").gt(lit(0.8)))?
3911 .project(vec![col("t.a"), col("t.r")])?
3912 .build()?;
3913
3914 assert_snapshot!(plan,
3915 @r"
3916 Projection: t.a, t.r
3917 Filter: t.r > Float64(0.8)
3918 SubqueryAlias: t
3919 Projection: test1.a AS a, TestScalarUDF() AS r
3920 Inner Join: test1.a = test2.a
3921 TableScan: test1
3922 TableScan: test2
3923 ",
3924 );
3925 assert_optimized_plan_equal!(
3926 plan,
3927 @r"
3928 Projection: t.a, t.r
3929 SubqueryAlias: t
3930 Filter: r > Float64(0.8)
3931 Projection: test1.a AS a, TestScalarUDF() AS r
3932 Inner Join: test1.a = test2.a
3933 TableScan: test1
3934 TableScan: test2
3935 "
3936 )
3937 }
3938
/// A filter made only of a volatile function call, placed above a projection.
///
/// The expected optimized plan moves the filter below the projection but keeps
/// it as a standalone `Filter` node. The scan receives no `full_filters`.
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));
    let predicate = volatile_call.gt(lit(0.1));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
3968
/// Mixes one volatile predicate with two ordinary column predicates above a
/// projection over `test_table_scan()`.
///
/// Expected outcome: the non-volatile conjuncts end up as the scan's
/// `full_filters`. The volatile conjunct stays in a `Filter` node directly
/// above the scan.
///
/// The column predicates deliberately name the projection's real input columns
/// (`a`/`b`, normalized to `test.a`/`test.b` by the builder). A qualifier like
/// `t.` would not resolve here because this plan has no relation named `t`. The
/// test would then exercise pushdown of unresolvable column references rather
/// than real ones.
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    let table_scan = test_table_scan()?;
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .filter(
            expr.gt(lit(0.1))
                .and(col("a").gt(lit(5)))
                .and(col("b").gt(lit(10))),
        )?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND test.a > Int32(5) AND test.b > Int32(10)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test, full_filters=[test.a > Int32(5), test.b > Int32(10)]
    "
    )
}
4002
/// Same predicate mix as the Exact-pushdown case, but over a provider that
/// reports `TableProviderFilterPushDown::Unsupported` for every filter.
///
/// Expected outcome: nothing becomes a scan filter. All conjuncts are kept in a
/// single `Filter` directly above the scan, with the non-volatile predicates
/// first and the volatile one last.
///
/// The provider's schema is unqualified, so the column predicates use the bare
/// names `a`/`b`. A `t.` qualifier would not resolve because there is no
/// relation `t` in this plan.
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?
    .project(vec![col("a"), col("b")])?
    .filter(
        expr.gt(lit(0.1))
            .and(col("a").gt(lit(5)))
            .and(col("b").gt(lit(10))),
    )?
    .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND a > Int32(5) AND b > Int32(10)
      Projection: a, b
        TableScan: test
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: a > Int32(5) AND b > Int32(10) AND TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4039
/// A `Filter` directly on top of a user-defined extension node that has no
/// inputs. The expected optimized plan is identical to the input plan.
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    /// Leaf extension node exposing a single non-nullable `Int64` column `a`.
    #[derive(Debug, Hash, Eq, PartialEq)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    // The node has no meaningful ordering, so every comparison is reported as
    // incomparable.
    impl PartialOrd for TestUserNode {
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl TestUserNode {
        fn new() -> Self {
            let column_a = Field::new("a", DataType::Int64, false);
            // Test-only construction; a failure here should abort the test.
            let df_schema =
                DFSchema::new_with_metadata(vec![(None, column_a.into())], Default::default())
                    .unwrap();
            Self {
                schema: Arc::new(df_schema),
            }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        // Leaf node: no child plans.
        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        // The node owns no expressions.
        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            f.write_str("TestUserNode")
        }

        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            // Since `expressions()` and `inputs()` are both empty, nothing
            // should ever be handed back here.
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            let schema = Arc::clone(&self.schema);
            Ok(Self { schema })
        }
    }

    let leaf = LogicalPlan::Extension(Extension {
        node: Arc::new(TestUserNode::new()),
    });
    let plan = LogicalPlanBuilder::from(leaf)
        .filter(lit(false))?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: Boolean(false)
      TestUserNode
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: Boolean(false)
      TestUserNode
    "
    )
}
4125}