1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use indexmap::IndexSet;
24use itertools::Itertools;
25
26use datafusion_common::tree_node::{
27 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
28};
29use datafusion_common::{
30 internal_err, plan_err, qualified_name, Column, DFSchema, Result,
31};
32use datafusion_expr::expr::WindowFunction;
33use datafusion_expr::expr_rewriter::replace_col;
34use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
35use datafusion_expr::utils::{
36 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
37};
38use datafusion_expr::{
39 and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
40};
41
42use crate::optimizer::ApplyOrder;
43use crate::simplify_expressions::simplify_predicates;
44use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
45use crate::{OptimizerConfig, OptimizerRule};
46
/// Optimizer rule that moves filter predicates closer to the inputs of a
/// logical plan.
///
/// The struct itself carries no state; the behaviour lives in its
/// `OptimizerRule` implementation, which is registered under the name
/// `"push_down_filter"` and is applied top-down.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
136
137pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
162 match join_type {
163 JoinType::Inner => (true, true),
164 JoinType::Left => (true, false),
165 JoinType::Right => (false, true),
166 JoinType::Full => (false, false),
167 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
170 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
173 }
174}
175
176pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
186 match join_type {
187 JoinType::Inner => (true, true),
188 JoinType::Left => (false, true),
189 JoinType::Right => (true, false),
190 JoinType::Full => (false, false),
191 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
192 JoinType::LeftAnti => (false, true),
193 JoinType::RightAnti => (true, false),
194 JoinType::LeftMark => (false, true),
195 JoinType::RightMark => (true, false),
196 }
197}
198
/// Answers "does this predicate only reference columns of the left (or right)
/// join input?" for a pair of join input schemas.
///
/// The column sets are derived from the schemas on first use and then reused.
#[derive(Debug)]
struct ColumnChecker<'a> {
    /// Schema of the left join input.
    left_schema: &'a DFSchema,
    /// Columns of `left_schema`; `None` until first needed.
    left_columns: Option<HashSet<Column>>,
    /// Schema of the right join input.
    right_schema: &'a DFSchema,
    /// Columns of `right_schema`; `None` until first needed.
    right_columns: Option<HashSet<Column>>,
}
212
213impl<'a> ColumnChecker<'a> {
214 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
215 Self {
216 left_schema,
217 left_columns: None,
218 right_schema,
219 right_columns: None,
220 }
221 }
222
223 fn is_left_only(&mut self, predicate: &Expr) -> bool {
225 if self.left_columns.is_none() {
226 self.left_columns = Some(schema_columns(self.left_schema));
227 }
228 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
229 }
230
231 fn is_right_only(&mut self, predicate: &Expr) -> bool {
233 if self.right_columns.is_none() {
234 self.right_columns = Some(schema_columns(self.right_schema));
235 }
236 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
237 }
238}
239
240fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
242 schema
243 .iter()
244 .flat_map(|(qualifier, field)| {
245 [
246 Column::new(qualifier.cloned(), field.name()),
247 Column::new_unqualified(field.name()),
249 ]
250 })
251 .collect::<HashSet<_>>()
252}
253
254fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
256 let mut is_evaluate = true;
257 predicate.apply(|expr| match expr {
258 Expr::Column(_)
259 | Expr::Literal(_, _)
260 | Expr::Placeholder(_)
261 | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
262 Expr::Exists { .. }
263 | Expr::InSubquery(_)
264 | Expr::ScalarSubquery(_)
265 | Expr::OuterReferenceColumn(_, _)
266 | Expr::Unnest(_) => {
267 is_evaluate = false;
268 Ok(TreeNodeRecursion::Stop)
269 }
270 Expr::Alias(_)
271 | Expr::BinaryExpr(_)
272 | Expr::Like(_)
273 | Expr::SimilarTo(_)
274 | Expr::Not(_)
275 | Expr::IsNotNull(_)
276 | Expr::IsNull(_)
277 | Expr::IsTrue(_)
278 | Expr::IsFalse(_)
279 | Expr::IsUnknown(_)
280 | Expr::IsNotTrue(_)
281 | Expr::IsNotFalse(_)
282 | Expr::IsNotUnknown(_)
283 | Expr::Negative(_)
284 | Expr::Between(_)
285 | Expr::Case(_)
286 | Expr::Cast(_)
287 | Expr::TryCast(_)
288 | Expr::InList { .. }
289 | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
290 #[expect(deprecated)]
292 Expr::AggregateFunction(_)
293 | Expr::WindowFunction(_)
294 | Expr::Wildcard { .. }
295 | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
296 })?;
297 Ok(is_evaluate)
298}
299
300fn extract_or_clauses_for_join<'a>(
334 filters: &'a [Expr],
335 schema: &'a DFSchema,
336) -> impl Iterator<Item = Expr> + 'a {
337 let schema_columns = schema_columns(schema);
338
339 filters.iter().filter_map(move |expr| {
341 if let Expr::BinaryExpr(BinaryExpr {
342 left,
343 op: Operator::Or,
344 right,
345 }) = expr
346 {
347 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
348 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
349
350 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
352 return Some(or(left_expr, right_expr));
353 }
354 }
355 None
356 })
357}
358
359fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
371 let mut predicate = None;
372
373 match expr {
374 Expr::BinaryExpr(BinaryExpr {
375 left: l_expr,
376 op: Operator::Or,
377 right: r_expr,
378 }) => {
379 let l_expr = extract_or_clause(l_expr, schema_columns);
380 let r_expr = extract_or_clause(r_expr, schema_columns);
381
382 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
383 predicate = Some(or(l_expr, r_expr));
384 }
385 }
386 Expr::BinaryExpr(BinaryExpr {
387 left: l_expr,
388 op: Operator::And,
389 right: r_expr,
390 }) => {
391 let l_expr = extract_or_clause(l_expr, schema_columns);
392 let r_expr = extract_or_clause(r_expr, schema_columns);
393
394 match (l_expr, r_expr) {
395 (Some(l_expr), Some(r_expr)) => {
396 predicate = Some(and(l_expr, r_expr));
397 }
398 (Some(l_expr), None) => {
399 predicate = Some(l_expr);
400 }
401 (None, Some(r_expr)) => {
402 predicate = Some(r_expr);
403 }
404 (None, None) => {
405 predicate = None;
406 }
407 }
408 }
409 _ => {
410 if has_all_column_refs(expr, schema_columns) {
411 predicate = Some(expr.clone());
412 }
413 }
414 }
415
416 predicate
417}
418
419fn push_down_all_join(
421 predicates: Vec<Expr>,
422 inferred_join_predicates: Vec<Expr>,
423 mut join: Join,
424 on_filter: Vec<Expr>,
425) -> Result<Transformed<LogicalPlan>> {
426 let is_inner_join = join.join_type == JoinType::Inner;
427 let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
429
430 let left_schema = join.left.schema();
435 let right_schema = join.right.schema();
436 let mut left_push = vec![];
437 let mut right_push = vec![];
438 let mut keep_predicates = vec![];
439 let mut join_conditions = vec![];
440 let mut checker = ColumnChecker::new(left_schema, right_schema);
441 for predicate in predicates {
442 if left_preserved && checker.is_left_only(&predicate) {
443 left_push.push(predicate);
444 } else if right_preserved && checker.is_right_only(&predicate) {
445 right_push.push(predicate);
446 } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
447 join_conditions.push(predicate);
450 } else {
451 keep_predicates.push(predicate);
452 }
453 }
454
455 for predicate in inferred_join_predicates {
457 if left_preserved && checker.is_left_only(&predicate) {
458 left_push.push(predicate);
459 } else if right_preserved && checker.is_right_only(&predicate) {
460 right_push.push(predicate);
461 }
462 }
463
464 let mut on_filter_join_conditions = vec![];
465 let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
466
467 if !on_filter.is_empty() {
468 for on in on_filter {
469 if on_left_preserved && checker.is_left_only(&on) {
470 left_push.push(on)
471 } else if on_right_preserved && checker.is_right_only(&on) {
472 right_push.push(on)
473 } else {
474 on_filter_join_conditions.push(on)
475 }
476 }
477 }
478
479 if left_preserved {
482 left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
483 left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
484 }
485 if right_preserved {
486 right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
487 right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
488 }
489
490 if on_left_preserved {
493 left_push.extend(extract_or_clauses_for_join(
494 &on_filter_join_conditions,
495 left_schema,
496 ));
497 }
498 if on_right_preserved {
499 right_push.extend(extract_or_clauses_for_join(
500 &on_filter_join_conditions,
501 right_schema,
502 ));
503 }
504
505 if let Some(predicate) = conjunction(left_push) {
506 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
507 }
508 if let Some(predicate) = conjunction(right_push) {
509 join.right =
510 Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
511 }
512
513 join_conditions.extend(on_filter_join_conditions);
515 join.filter = conjunction(join_conditions);
516
517 let plan = LogicalPlan::Join(join);
519 let plan = if let Some(predicate) = conjunction(keep_predicates) {
520 LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
521 } else {
522 plan
523 };
524 Ok(Transformed::yes(plan))
525}
526
527fn push_down_join(
528 join: Join,
529 parent_predicate: Option<&Expr>,
530) -> Result<Transformed<LogicalPlan>> {
531 let predicates = parent_predicate
533 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
534
535 let on_filters = join
537 .filter
538 .as_ref()
539 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
540
541 let inferred_join_predicates =
543 infer_join_predicates(&join, &predicates, &on_filters)?;
544
545 if on_filters.is_empty()
546 && predicates.is_empty()
547 && inferred_join_predicates.is_empty()
548 {
549 return Ok(Transformed::no(LogicalPlan::Join(join)));
550 }
551
552 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
553}
554
555fn infer_join_predicates(
566 join: &Join,
567 predicates: &[Expr],
568 on_filters: &[Expr],
569) -> Result<Vec<Expr>> {
570 let join_col_keys = join
572 .on
573 .iter()
574 .filter_map(|(l, r)| {
575 let left_col = l.try_as_col()?;
576 let right_col = r.try_as_col()?;
577 Some((left_col, right_col))
578 })
579 .collect::<Vec<_>>();
580
581 let join_type = join.join_type;
582
583 let mut inferred_predicates = InferredPredicates::new(join_type);
584
585 infer_join_predicates_from_predicates(
586 &join_col_keys,
587 predicates,
588 &mut inferred_predicates,
589 )?;
590
591 infer_join_predicates_from_on_filters(
592 &join_col_keys,
593 join_type,
594 on_filters,
595 &mut inferred_predicates,
596 )?;
597
598 Ok(inferred_predicates.predicates)
599}
600
/// Accumulates predicates inferred from join keys.
struct InferredPredicates {
    /// Predicates accepted so far, already rewritten to the other join side.
    predicates: Vec<Expr>,
    /// Whether the join is an inner join; for inner joins every candidate is
    /// accepted without the `is_restrict_null_predicate` check.
    is_inner_join: bool,
}
613
614impl InferredPredicates {
615 fn new(join_type: JoinType) -> Self {
616 Self {
617 predicates: vec![],
618 is_inner_join: matches!(join_type, JoinType::Inner),
619 }
620 }
621
622 fn try_build_predicate(
623 &mut self,
624 predicate: Expr,
625 replace_map: &HashMap<&Column, &Column>,
626 ) -> Result<()> {
627 if self.is_inner_join
628 || matches!(
629 is_restrict_null_predicate(
630 predicate.clone(),
631 replace_map.keys().cloned()
632 ),
633 Ok(true)
634 )
635 {
636 self.predicates.push(replace_col(predicate, replace_map)?);
637 }
638
639 Ok(())
640 }
641}
642
/// Infers predicates from the conjuncts of the filter above the join.
///
/// Inference is enabled in both directions (left key to right key and right
/// key to left key); accepted predicates are appended to
/// `inferred_predicates`.
fn infer_join_predicates_from_predicates(
    join_col_keys: &[(&Column, &Column)],
    predicates: &[Expr],
    inferred_predicates: &mut InferredPredicates,
) -> Result<()> {
    infer_join_predicates_impl::<true, true>(
        join_col_keys,
        predicates,
        inferred_predicates,
    )
}
663
664fn infer_join_predicates_from_on_filters(
677 join_col_keys: &[(&Column, &Column)],
678 join_type: JoinType,
679 on_filters: &[Expr],
680 inferred_predicates: &mut InferredPredicates,
681) -> Result<()> {
682 match join_type {
683 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
684 JoinType::Inner => infer_join_predicates_impl::<true, true>(
685 join_col_keys,
686 on_filters,
687 inferred_predicates,
688 ),
689 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
690 infer_join_predicates_impl::<true, false>(
691 join_col_keys,
692 on_filters,
693 inferred_predicates,
694 )
695 }
696 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
697 infer_join_predicates_impl::<false, true>(
698 join_col_keys,
699 on_filters,
700 inferred_predicates,
701 )
702 }
703 }
704}
705
706fn infer_join_predicates_impl<
723 const ENABLE_LEFT_TO_RIGHT: bool,
724 const ENABLE_RIGHT_TO_LEFT: bool,
725>(
726 join_col_keys: &[(&Column, &Column)],
727 input_predicates: &[Expr],
728 inferred_predicates: &mut InferredPredicates,
729) -> Result<()> {
730 for predicate in input_predicates {
731 let mut join_cols_to_replace = HashMap::new();
732
733 for &col in &predicate.column_refs() {
734 for (l, r) in join_col_keys.iter() {
735 if ENABLE_LEFT_TO_RIGHT && col == *l {
736 join_cols_to_replace.insert(col, *r);
737 break;
738 }
739 if ENABLE_RIGHT_TO_LEFT && col == *r {
740 join_cols_to_replace.insert(col, *l);
741 break;
742 }
743 }
744 }
745 if join_cols_to_replace.is_empty() {
746 continue;
747 }
748
749 inferred_predicates
750 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
751 }
752 Ok(())
753}
754
impl OptimizerRule for PushDownFilter {
    /// Name of this rule.
    fn name(&self) -> &str {
        "push_down_filter"
    }

    /// The rule asks to be applied top-down.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// Reports that `rewrite` is implemented.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites a single plan node.
    ///
    /// * A `Join` node has its own filter (and key-inferred predicates)
    ///   pushed into its inputs.
    /// * A `Filter` node is simplified and then, depending on the kind of its
    ///   input, pushed below that input in whole or in part.
    /// * Every other node is returned unchanged.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // A join without a filter above it: only its ON filter / inferred
        // predicates are candidates.
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        // Captured before `plan` is destructured; used by the Union arm.
        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        // Simplify the conjuncts; the predicate is only replaced when the
        // number of conjuncts changed.
        let predicate = split_conjunction_owned(filter.predicate.clone());
        let old_predicate_len = predicate.len();
        let new_predicates = simplify_predicates(predicate)?;
        if old_predicate_len != new_predicates.len() {
            let Some(new_predicate) = conjunction(new_predicates) else {
                // No conjunct is left: drop the filter and return its input.
                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
            };
            filter.predicate = new_predicate;
        }

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over Filter: merge both into one filter and retry.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                // Going through an `IndexSet` removes duplicate conjuncts
                // while keeping first-seen order.
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                #[allow(clippy::used_underscore_binding)]
                self.rewrite(new_filter, _config)
            }
            // Repartition, Distinct and Sort: the whole filter moves below.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            // SubqueryAlias: rewrite alias-qualified columns to the input's
            // columns (matched by position), then move the filter below.
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            // Projection: see `rewrite_projection`; conjuncts it reports as
            // kept are re-applied above the rewritten projection.
            LogicalPlan::Projection(projection) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: restore the original filter.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            // Unnest: conjuncts that reference an unnest output column stay
            // above; the rest is pushed below the unnest.
            LogicalPlan::Unnest(mut unnest) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                for predicate in predicates {
                    // All columns referenced by this conjunct.
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    if unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                        accum.contains(&unnest_list.output_column)
                    }) {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                // Nothing can be pushed: restore the original filter.
                if non_unnest_predicates.is_empty() {
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot be `None`: `non_unnest_predicates` is non-empty here.
                    conjunction(non_unnest_predicates).unwrap(),
                    unnest_input,
                )?);

                // The new filter becomes the unnest's input.
                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            // Union: a copy of the filter, with column names translated to
            // each input's schema by position, goes below every input.
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    schema: Arc::clone(&plan_schema),
                })))
            }
            // Aggregate: only conjuncts whose columns are all group-by
            // outputs are pushed below.
            LogicalPlan::Aggregate(agg) => {
                // Output columns of the group-by expressions, by schema name.
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
                    .collect::<Result<HashSet<_>>>()?;

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // Below the aggregate the group-by outputs do not exist yet,
                // so references to them are replaced by the group-by
                // expressions themselves.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // Pushed conjuncts go between the aggregate and its input.
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Kept conjuncts are re-applied above the aggregate.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            // Window: only conjuncts whose columns are partition keys shared
            // by every window expression are pushed below.
            LogicalPlan::Window(window) => {
                // Partition-by keys of one window function, as columns named
                // after the key expressions' schema names.
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    // NOTE(review): presumably window_expr only holds
                                    // (aliased) window functions; confirm.
                                    unreachable!()
                                }
                            }
                            _ => {
                                // NOTE(review): same assumption as above.
                                unreachable!()
                            }
                        }
                    })
                    // Set intersection across all window expressions.
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // Pushed conjuncts go between the window and its input.
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Kept conjuncts are re-applied above the window.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            // TableScan: offer the non-volatile conjuncts to the table source.
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile conjuncts are never offered to the source.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // One answer per offered conjunct is expected.
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                if non_volatile_filters.len() != supported_filters.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        supported_filters.len(),
                        non_volatile_filters.len());
                }

                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // Everything the source did not report as `Unsupported` is
                // added to the scan's filters ...
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // ... after the filters the scan already has, without duplicates.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Everything not reported as `Exact`, plus the volatile
                // conjuncts, stays in a filter above the scan.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            // Extension node: push conjuncts that avoid the node's
            // "prevent push down" columns below every input of the node.
            LogicalPlan::Extension(extension_plan) => {
                // No inputs to push into: restore the original filter.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` = push, `false` = keep.
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            Ok(false) // keep
                        } else {
                            Ok(true) // push
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Every conjunct must be kept: restore the original filter.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Split the (now owned) conjuncts according to the flags;
                // both sequences come from the same predicate, in the same order.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                // Each input of the node gets the same pushed predicate.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the extension node over the new children.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input: the filter stays where it is.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1237
1238fn rewrite_projection(
1266 predicates: Vec<Expr>,
1267 mut projection: Projection,
1268) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1269 let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1273 .schema
1274 .iter()
1275 .zip(projection.expr.iter())
1276 .map(|((qualifier, field), expr)| {
1277 let expr = expr.clone().unalias();
1279
1280 (qualified_name(qualifier, field.name()), expr)
1281 })
1282 .partition(|(_, value)| value.is_volatile());
1283
1284 let mut push_predicates = vec![];
1285 let mut keep_predicates = vec![];
1286 for expr in predicates {
1287 if contain(&expr, &volatile_map) {
1288 keep_predicates.push(expr);
1289 } else {
1290 push_predicates.push(expr);
1291 }
1292 }
1293
1294 match conjunction(push_predicates) {
1295 Some(expr) => {
1296 let new_filter = LogicalPlan::Filter(Filter::try_new(
1299 replace_cols_by_name(expr, &non_volatile_map)?,
1300 std::mem::take(&mut projection.input),
1301 )?);
1302
1303 projection.input = Arc::new(new_filter);
1304
1305 Ok((
1306 Transformed::yes(LogicalPlan::Projection(projection)),
1307 conjunction(keep_predicates),
1308 ))
1309 }
1310 None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1311 }
1312}
1313
1314pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1316 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1317}
1318
1319fn insert_below(
1333 plan: LogicalPlan,
1334 new_child: LogicalPlan,
1335) -> Result<Transformed<LogicalPlan>> {
1336 let mut new_child = Some(new_child);
1337 let transformed_plan = plan.map_children(|_child| {
1338 if let Some(new_child) = new_child.take() {
1339 Ok(Transformed::yes(new_child))
1340 } else {
1341 internal_err!("node had more than one input")
1343 }
1344 })?;
1345
1346 if new_child.is_some() {
1348 return internal_err!("node had no inputs");
1349 }
1350
1351 Ok(transformed_plan)
1352}
1353
1354impl PushDownFilter {
1355 #[allow(missing_docs)]
1356 pub fn new() -> Self {
1357 Self {}
1358 }
1359}
1360
1361pub fn replace_cols_by_name(
1363 e: Expr,
1364 replace_map: &HashMap<String, Expr>,
1365) -> Result<Expr> {
1366 e.transform_up(|expr| {
1367 Ok(if let Expr::Column(c) = &expr {
1368 match replace_map.get(&c.flat_name()) {
1369 Some(new_c) => Transformed::yes(new_c.clone()),
1370 None => Transformed::no(expr),
1371 }
1372 } else {
1373 Transformed::no(expr)
1374 })
1375 })
1376 .data()
1377}
1378
1379fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1381 let mut is_contain = false;
1382 e.apply(|expr| {
1383 Ok(if let Expr::Column(c) = &expr {
1384 match check_map.get(&c.flat_name()) {
1385 Some(_) => {
1386 is_contain = true;
1387 TreeNodeRecursion::Stop
1388 }
1389 None => TreeNodeRecursion::Continue,
1390 }
1391 } else {
1392 TreeNodeRecursion::Continue
1393 })
1394 })
1395 .unwrap();
1396 is_contain
1397}
1398
1399#[cfg(test)]
1400mod tests {
1401 use std::any::Any;
1402 use std::cmp::Ordering;
1403 use std::fmt::{Debug, Formatter};
1404
1405 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1406 use async_trait::async_trait;
1407
1408 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1409 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1410 use datafusion_expr::logical_plan::table_scan;
1411 use datafusion_expr::{
1412 col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1413 LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1414 TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1415 WindowFunctionDefinition,
1416 };
1417
1418 use crate::assert_optimized_plan_eq_snapshot;
1419 use crate::optimizer::Optimizer;
1420 use crate::simplify_expressions::SimplifyExpressions;
1421 use crate::test::*;
1422 use crate::OptimizerContext;
1423 use datafusion_expr::test::function_stub::sum;
1424 use insta::assert_snapshot;
1425
1426 use super::*;
1427
1428 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1429
1430 macro_rules! assert_optimized_plan_equal {
1431 (
1432 $plan:expr,
1433 @ $expected:literal $(,)?
1434 ) => {{
1435 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1436 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1437 assert_optimized_plan_eq_snapshot!(
1438 optimizer_ctx,
1439 rules,
1440 $plan,
1441 @ $expected,
1442 )
1443 }};
1444 }
1445
1446 macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1447 (
1448 $plan:expr,
1449 @ $expected:literal $(,)?
1450 ) => {{
1451 let optimizer = Optimizer::with_rules(vec![
1452 Arc::new(SimplifyExpressions::new()),
1453 Arc::new(PushDownFilter::new()),
1454 ]);
1455 let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1456 assert_snapshot!(optimized_plan, @ $expected);
1457 Ok::<(), DataFusionError>(())
1458 }};
1459 }
1460
1461 #[test]
1462 fn filter_before_projection() -> Result<()> {
1463 let table_scan = test_table_scan()?;
1464 let plan = LogicalPlanBuilder::from(table_scan)
1465 .project(vec![col("a"), col("b")])?
1466 .filter(col("a").eq(lit(1i64)))?
1467 .build()?;
1468 assert_optimized_plan_equal!(
1470 plan,
1471 @r"
1472 Projection: test.a, test.b
1473 TableScan: test, full_filters=[test.a = Int64(1)]
1474 "
1475 )
1476 }
1477
1478 #[test]
1479 fn filter_after_limit() -> Result<()> {
1480 let table_scan = test_table_scan()?;
1481 let plan = LogicalPlanBuilder::from(table_scan)
1482 .project(vec![col("a"), col("b")])?
1483 .limit(0, Some(10))?
1484 .filter(col("a").eq(lit(1i64)))?
1485 .build()?;
1486 assert_optimized_plan_equal!(
1488 plan,
1489 @r"
1490 Filter: test.a = Int64(1)
1491 Limit: skip=0, fetch=10
1492 Projection: test.a, test.b
1493 TableScan: test
1494 "
1495 )
1496 }
1497
/// A predicate that references no columns at all (`0 = 1`) is still moved
/// into the table scan.
#[test]
fn filter_no_columns() -> Result<()> {
    let constant_predicate = lit(0i64).eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(constant_predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
    )
}
1509
/// The filter crosses two stacked projections before landing in the scan.
#[test]
fn filter_jump_2_plans() -> Result<()> {
    let inner_columns = vec![col("a"), col("b"), col("c")];
    let outer_columns = vec![col("c"), col("b")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(inner_columns)?
        .project(outer_columns)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.c, test.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1528
/// A filter on a group-by column moves below the aggregate and into the
/// scan.
#[test]
fn filter_move_agg() -> Result<()> {
    let group_by = vec![col("a")];
    let aggregates = vec![sum(col("b")).alias("total_salary")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1545
/// The aggregate groups by the expression `b + a`; a filter on the bare
/// column `b` is not a filter on that group-by expression and therefore
/// stays above the aggregate.
#[test]
fn filter_complex_group_by() -> Result<()> {
    let group_by = vec![add(col("b"), col("a"))];
    let aggregates = vec![sum(col("a")), col("b")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(group_by, aggregates)?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b > Int64(10)
      Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
        TableScan: test
    "
    )
}
1562
/// The filter refers to the aggregate's output column named
/// `test.b + test.a`; when pushed below the aggregate it shows up in the
/// scan as the underlying expression `test.b + test.a`.
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
    let aggregated = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?;
    let plan = aggregated
        .filter(col("test.b + test.a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
      TableScan: test, full_filters=[test.b + test.a > Int64(10)]
    "
    )
}
1577
/// Here `b` is the alias of the aggregate result `sum(test.b)`, so the
/// filter on `b` has to remain above the aggregate.
#[test]
fn filter_keep_agg() -> Result<()> {
    let summed_as_b = sum(col("b")).alias("b");
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .aggregate(vec![col("a")], vec![summed_as_b])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: b > Int64(10)
      Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
        TableScan: test
    "
    )
}
1595
/// A filter on `b`, one of the window's partition columns (`a`, `b`), is
/// pushed below the window node into the scan.
#[test]
fn filter_move_window() -> Result<()> {
    let rank_call = WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf()),
        vec![],
    );
    let window = Expr::from(rank_call)
        .partition_by(vec![col("a"), col("b")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![window])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.b > Int64(10)]
    "
    )
}
1625
/// A conjunction whose two terms each reference a partition column (`a`
/// and `b`) is pushed below the window as two separate scan filters.
#[test]
fn filter_move_complex_window() -> Result<()> {
    let rank_call = WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf()),
        vec![],
    );
    let window = Expr::from(rank_call)
        .partition_by(vec![col("a"), col("b")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![window])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
    "
    )
}
1656
/// The window partitions by `a` only. Of the conjunction
/// `a > 10 AND b = 1`, the `a` term goes down to the scan while the `b`
/// term stays in a filter above the window.
#[test]
fn filter_move_partial_window() -> Result<()> {
    let rank_call = WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf()),
        vec![],
    );
    let window = Expr::from(rank_call)
        .partition_by(vec![col("a")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![window])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b = Int64(1)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1687
/// The window partitions by the expression `a + b`. A filter on that same
/// expression is nevertheless kept above the window node.
#[test]
fn filter_expression_keep_window() -> Result<()> {
    let rank_call = WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf()),
        vec![],
    );
    let window = Expr::from(rank_call)
        .partition_by(vec![add(col("a"), col("b"))])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let predicate = add(col("a"), col("b")).gt(lit(10i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![window])?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a + test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1721
/// `c` is only the window's ORDER BY column, not a partition column, so a
/// filter on `c` stays above the window node.
#[test]
fn filter_order_keep_window() -> Result<()> {
    let rank_call = WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(datafusion_functions_window::rank::rank_udwf()),
        vec![],
    );
    let window = Expr::from(rank_call)
        .partition_by(vec![col("a")])
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap();

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .window(vec![window])?
        .filter(col("c").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.c > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1752
/// Two window expressions partition by `[a]` and `[b, a]`. The filter is
/// on `a`, which both partitionings share, and it is pushed down to the
/// scan.
#[test]
fn filter_multiple_windows_common_partitions() -> Result<()> {
    // Builds `rank() PARTITION BY <partition> ORDER BY c ASC NULLS FIRST`.
    let rank_partitioned_by = |partition: Vec<Expr>| {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(partition)
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap()
    };

    let table_scan = test_table_scan()?;
    let window1 = rank_partitioned_by(vec![col("a")]);
    let window2 = rank_partitioned_by(vec![col("b"), col("a")]);

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window1, window2])?
        .filter(col("a").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
      TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
1794
/// Two window expressions partition by `[a]` and `[b, a]`. The filter is
/// on `b`, which only the second partitioning contains, so it stays above
/// the window node.
#[test]
fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
    // Builds `rank() PARTITION BY <partition> ORDER BY c ASC NULLS FIRST`.
    let rank_partitioned_by = |partition: Vec<Expr>| {
        Expr::from(WindowFunction::new(
            WindowFunctionDefinition::WindowUDF(
                datafusion_functions_window::rank::rank_udwf(),
            ),
            vec![],
        ))
        .partition_by(partition)
        .order_by(vec![col("c").sort(true, true)])
        .build()
        .unwrap()
    };

    let table_scan = test_table_scan()?;
    let window1 = rank_partitioned_by(vec![col("a")]);
    let window2 = rank_partitioned_by(vec![col("b"), col("a")]);

    let plan = LogicalPlanBuilder::from(table_scan)
        .window(vec![window1, window2])?
        .filter(col("b").gt(lit(10i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.b > Int64(10)
      WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
        TableScan: test
    "
    )
}
1837
/// The filter references `b`, which the projection defines as an alias of
/// `test.a`; once pushed below the projection it is expressed on `test.a`.
#[test]
fn alias() -> Result<()> {
    let projected = vec![col("a").alias("b"), col("c")];
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(projected)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a AS b, test.c
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
1855
1856 fn add(left: Expr, right: Expr) -> Expr {
1857 Expr::BinaryExpr(BinaryExpr::new(
1858 Box::new(left),
1859 Operator::Plus,
1860 Box::new(right),
1861 ))
1862 }
1863
1864 fn multiply(left: Expr, right: Expr) -> Expr {
1865 Expr::BinaryExpr(BinaryExpr::new(
1866 Box::new(left),
1867 Operator::Multiply,
1868 Box::new(right),
1869 ))
1870 }
1871
/// The projection defines `b` as `a * 2 + c`. A filter on `b` is pushed
/// below the projection with `b` replaced by that defining expression.
#[test]
fn complex_expression() -> Result<()> {
    let b_definition = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![b_definition, col("c")])?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: b = Int64(1)
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a * Int32(2) + test.c AS b, test.c
      TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
    "
    )
}
1901
/// Two stacked projections each define an alias in terms of the layer
/// below (`b = a * 2 + c`, then `a = b * 3`). The filter on the outer `a`
/// reaches the scan with both aliases expanded.
#[test]
fn complex_plan() -> Result<()> {
    let inner_b = add(multiply(col("a"), lit(2)), col("c")).alias("b");
    let outer_a = multiply(col("b"), lit(3)).alias("a");
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![inner_b, col("c")])?
        .project(vec![outer_a, col("c")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: a = Int64(1)
      Projection: b * Int32(3) AS a, test.c
        Projection: test.a * Int32(2) + test.c AS b, test.c
          TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: b * Int32(3) AS a, test.c
      Projection: test.a * Int32(2) + test.c AS b, test.c
        TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
    "
    )
}
1935
/// Test-only user-defined logical node with any number of inputs and a
/// fixed output schema.
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoopPlan {
    // Child plans of this node.
    input: Vec<LogicalPlan>,
    // Schema reported by the node.
    schema: DFSchemaRef,
}
1941
1942 impl PartialOrd for NoopPlan {
1944 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1945 self.input.partial_cmp(&other.input)
1946 }
1947 }
1948
1949 impl UserDefinedLogicalNodeCore for NoopPlan {
1950 fn name(&self) -> &str {
1951 "NoopPlan"
1952 }
1953
1954 fn inputs(&self) -> Vec<&LogicalPlan> {
1955 self.input.iter().collect()
1956 }
1957
1958 fn schema(&self) -> &DFSchemaRef {
1959 &self.schema
1960 }
1961
1962 fn expressions(&self) -> Vec<Expr> {
1963 self.input
1964 .iter()
1965 .flat_map(|child| child.expressions())
1966 .collect()
1967 }
1968
1969 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
1970 HashSet::from_iter(vec!["c".to_string()])
1971 }
1972
1973 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
1974 write!(f, "NoopPlan")
1975 }
1976
1977 fn with_exprs_and_inputs(
1978 &self,
1979 _exprs: Vec<Expr>,
1980 inputs: Vec<LogicalPlan>,
1981 ) -> Result<Self> {
1982 Ok(Self {
1983 input: inputs,
1984 schema: Arc::clone(&self.schema),
1985 })
1986 }
1987
1988 fn supports_limit_pushdown(&self) -> bool {
1989 false }
1991 }
1992
/// Push-down through the user-defined `NoopPlan` node, which lists column
/// `c` in `prevent_predicate_push_down_columns`.
///
/// Four cases are checked: one or two inputs, each with a filter on `a`
/// alone and with `a = 1 AND c = 2`. In every case the `a` predicate
/// reaches each input scan; the `c` predicate stays in a filter above the
/// node.
#[test]
fn user_defined_plan() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Wraps the given inputs in a `NoopPlan` extension node that reports
    // the table scan's schema.
    let noop_over = |inputs: Vec<LogicalPlan>| {
        LogicalPlan::Extension(Extension {
            node: Arc::new(NoopPlan {
                input: inputs,
                schema: Arc::clone(table_scan.schema()),
            }),
        })
    };
    let on_a = || col("a").eq(lit(1i64));
    let on_a_and_c = || col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64)));

    // One input, predicate on `a` only.
    let single_a = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(on_a())?
        .build()?;
    assert_optimized_plan_equal!(
        single_a,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // One input, predicate on `a` and on the blocked column `c`.
    let single_a_c = LogicalPlanBuilder::from(noop_over(vec![table_scan.clone()]))
        .filter(on_a_and_c())?
        .build()?;
    assert_optimized_plan_equal!(
        single_a_c,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate on `a` only.
    let double_a =
        LogicalPlanBuilder::from(noop_over(vec![table_scan.clone(), table_scan.clone()]))
            .filter(on_a())?
            .build()?;
    assert_optimized_plan_equal!(
        double_a,
        @r"
    NoopPlan
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )?;

    // Two inputs, predicate on `a` and on the blocked column `c`.
    let double_a_c =
        LogicalPlanBuilder::from(noop_over(vec![table_scan.clone(), table_scan.clone()]))
            .filter(on_a_and_c())?
            .build()?;
    assert_optimized_plan_equal!(
        double_a_c,
        @r"
    Filter: test.c = Int64(2)
      NoopPlan
        TableScan: test, full_filters=[test.a = Int64(1)]
        TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2077
/// Two stacked filters over an aggregate: the one on the group-by column
/// `b` is pushed down to the scan (rewritten onto `test.a`), the one on
/// the aggregate result `sum(test.c)` stays above the aggregate.
#[test]
fn multi_filter() -> Result<()> {
    let on_group_key = col("b").gt(lit(10i64));
    let on_aggregate = col("sum(test.c)").gt(lit(10i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(on_group_key)?
        .filter(on_aggregate)?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10)
      Filter: b > Int64(10)
        Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
          Projection: test.a AS b, test.c
            TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: sum(test.c) > Int64(10)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2112
/// A single conjunctive filter over an aggregate is split: the term on the
/// group-by column `b` goes down to the scan, the two terms on
/// `sum(test.c)` stay together above the aggregate.
#[test]
fn split_filter() -> Result<()> {
    let sum_above_10 = col("sum(test.c)").gt(lit(10i64));
    let key_above_10 = col("b").gt(lit(10i64));
    let sum_below_20 = col("sum(test.c)").lt(lit(20i64));
    let predicate = and(sum_above_10, and(key_above_10, sum_below_20));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
      Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10)]
    "
    )
}
2148
/// With two limits between the filter and the scan, the filter moves
/// through the top projection but stops above the first limit it meets.
#[test]
fn double_limit() -> Result<()> {
    let doubly_limited = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(20))?
        .limit(0, Some(10))?;
    let plan = doubly_limited
        .project(vec![col("a"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: test.a = Int64(1)
        Limit: skip=0, fetch=10
          Limit: skip=0, fetch=20
            Projection: test.a, test.b
              TableScan: test
    "
    )
}
2173
/// A filter above a union is copied into both union inputs, each copy
/// qualified with that input's table name.
#[test]
fn union_all() -> Result<()> {
    let first = test_table_scan()?;
    let second = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;
    let plan = LogicalPlanBuilder::from(first)
        .union(second)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Union
      TableScan: test, full_filters=[test.a = Int64(1)]
      TableScan: test2, full_filters=[test2.a = Int64(1)]
    "
    )
}
2192
/// Both union inputs are the same aliased projection (`test.a AS b` under
/// `SubqueryAlias: test2`). The filter on `b` is pushed through the union,
/// the alias and the projection into each scan as a filter on `test.a`.
#[test]
fn union_all_on_projection() -> Result<()> {
    let aliased_branch = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b")])?
        .alias("test2")?;

    let other_branch = aliased_branch.clone().build()?;
    let plan = aliased_branch
        .union(other_branch)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Union
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
      SubqueryAlias: test2
        Projection: test.a AS b
          TableScan: test, full_filters=[test.a = Int64(1)]
    "
    )
}
2220
/// Two tables with different column names (`a, b, c` and `d, e, f`) are
/// combined with a cross join (despite the test's name, no union is
/// built). Each conjunct of the filter references one side only and is
/// pushed into that side's scan.
#[test]
fn test_union_different_schema() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let right_fields = ["d", "e", "f"]
        .into_iter()
        .map(|name| Field::new(name, DataType::UInt32, false))
        .collect::<Vec<_>>();
    let schema = Schema::new(right_fields);
    let right = table_scan(Some("test1"), &schema, None)?
        .project(vec![col("d"), col("e"), col("f")])?
        .build()?;

    let on_left = col("test.a").eq(lit(1));
    let on_right = col("test1.d").gt(lit(2));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.d")])?
        .filter(and(on_left, on_right))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test1.d
      Cross Join: 
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.d, test1.e, test1.f
          TableScan: test1, full_filters=[test1.d > Int32(2)]
    "
    )
}
2254
/// Both sides of a cross join expose a column named `a`, distinguished
/// only by qualifier (`test.a` and `test1.a`). Each conjunct of the filter
/// is pushed to the scan matching its qualifier.
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
    let all_columns = || vec![col("a"), col("b"), col("c")];

    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(all_columns())?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(all_columns())?
        .build()?;

    let on_left = col("test.a").eq(lit(1));
    let on_right = col("test1.a").gt(lit(2));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.a")])?
        .filter(and(on_left, on_right))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test1.a
      Cross Join: 
        Projection: test.a, test.b, test.c
          TableScan: test, full_filters=[test.a = Int32(1)]
        Projection: test1.a, test1.b, test1.c
          TableScan: test1, full_filters=[test1.a > Int32(2)]
    "
    )
}
2284
/// Two filters separated by a limit: the one below the limit is pushed
/// into the scan, the one above it moves through the top projection but
/// stays above the limit.
#[test]
fn filter_2_breaks_limits() -> Result<()> {
    let below_limit = col("a").lt_eq(lit(1i64));
    let above_limit = col("a").gt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .filter(below_limit)?
        .limit(0, Some(1))?
        .project(vec![col("a")])?
        .filter(above_limit)?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a >= Int64(1)
      Projection: test.a
        Limit: skip=0, fetch=1
          Filter: test.a <= Int64(1)
            Projection: test.a
              TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Limit: skip=0, fetch=1
          Projection: test.a
            TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2320
/// Two adjacent filters directly above a limit are merged into one
/// conjunctive filter, which stays above the limit.
#[test]
fn two_filters_on_same_depth() -> Result<()> {
    let upper_bound = col("a").lt_eq(lit(1i64));
    let lower_bound = col("a").gt_eq(lit(1i64));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .limit(0, Some(1))?
        .filter(upper_bound)?
        .filter(lower_bound)?
        .project(vec![col("a")])?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Projection: test.a
      Filter: test.a >= Int64(1)
        Filter: test.a <= Int64(1)
          Limit: skip=0, fetch=1
            TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a
      Filter: test.a >= Int64(1) AND test.a <= Int64(1)
        Limit: skip=0, fetch=1
          TableScan: test
    "
    )
}
2352
/// A filter that sits *below* a user-defined node is still pushed into
/// the scan; the user-defined node on top is left in place.
#[test]
fn filters_user_defined_node() -> Result<()> {
    let filtered_scan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;
    let plan = user_defined::new(filtered_scan);

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    TestUserDefined
      Filter: test.a <= Int64(1)
        TableScan: test
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    TestUserDefined
      TableScan: test, full_filters=[test.a <= Int64(1)]
    "
    )
}
2380
/// Inner join on `a` with a post-join filter on the join key `test.a`:
/// the filter is pushed to the left scan and a matching filter on
/// `test2.a` appears on the right scan.
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, None)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2421
/// Same scenario as the inner `ON` join, but built with `USING (a)` and
/// an unqualified filter column: the filter ends up on both scans.
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Inner, using_columns)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Inner Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2461
/// The post-join filter `c <= b` compares a left column with a right
/// column, so it cannot go to either side alone; it becomes the inner
/// join's filter instead.
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, None)?
        .filter(col("c").lt_eq(col("b")))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.c <= test2.b
      Inner Join: test.a = test2.a
        Projection: test.a, test.c
          TableScan: test
        Projection: test2.a, test2.b
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.c <= test2.b
      Projection: test.a, test.c
        TableScan: test
      Projection: test2.a, test2.b
        TableScan: test2
    "
    )
}
2506
/// The post-join filter references `b`, a non-key column that only the
/// left side provides; it is pushed to the left scan and the right side is
/// untouched.
#[test]
fn filter_join_on_one_side() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("c")])?
        .build()?;

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, None)?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.b <= Int64(1)
      Inner Join: test.a = test2.a
        Projection: test.a, test.b
          TableScan: test
        Projection: test2.a, test2.c
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a
      Projection: test.a, test.b
        TableScan: test, full_filters=[test.b <= Int64(1)]
      Projection: test2.a, test2.c
        TableScan: test2
    "
    )
}
2551
/// Left join `USING (a)` with a post-join filter on the right-side key
/// `test2.a`. The filter itself stays above the join; a corresponding
/// filter on `test.a` appears on the left scan, and the right scan gets
/// nothing.
#[test]
fn filter_using_left_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Left, using_columns)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test2.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test, full_filters=[test.a <= Int64(1)]
        Projection: test2.a
          TableScan: test2
    "
    )
}
2593
/// Right join `USING (a)` with a post-join filter on the left-side key
/// `test.a`. The filter itself stays above the join; a corresponding
/// filter on `test2.a` appears on the right scan, and the left scan gets
/// nothing.
#[test]
fn filter_using_right_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Right, using_columns)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: test.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2634
/// Left join `USING (a)` with a post-join filter that resolves to the
/// left-side key `test.a`: the filter is pushed to the left scan only and
/// no filter remains above the join.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Left, using_columns)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test.a <= Int64(1)
      Left Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Left Join: Using test.a = test2.a
      TableScan: test, full_filters=[test.a <= Int64(1)]
      Projection: test2.a
        TableScan: test2
    "
    )
}
2675
/// Right join `USING (a)` with a post-join filter on the right-side key
/// `test2.a`: the filter is pushed to the right scan only and no filter
/// remains above the join.
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;

    let using_columns = vec![Column::from_name("a".to_string())];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Right, using_columns)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Filter: test2.a <= Int64(1)
      Right Join: Using test.a = test2.a
        TableScan: test
        Projection: test2.a
          TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Right Join: Using test.a = test2.a
      TableScan: test
      Projection: test2.a
        TableScan: test2, full_filters=[test2.a <= Int64(1)]
    "
    )
}
2716
/// Inner join whose own filter has three conjuncts: one on the left side
/// only, one comparing both sides, one on the right side only. The
/// single-side conjuncts are pushed to their scans; the two-sided one
/// remains as the join filter.
#[test]
fn join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.c").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Plan as built, before optimization.
    assert_snapshot!(plan,
    @r"
    Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
      Projection: test.a, test.b, test.c
        TableScan: test
      Projection: test2.a, test2.b, test2.c
        TableScan: test2
    ",
    );
    // Plan after push-down.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Inner Join: test.a = test2.a Filter: test.b < test2.b
      Projection: test.a, test.b, test.c
        TableScan: test, full_filters=[test.c > UInt32(1)]
      Projection: test2.a, test2.b, test2.c
        TableScan: test2, full_filters=[test2.c > UInt32(4)]
    "
    )
}
2762
2763 #[test]
2765 fn join_filter_removed() -> Result<()> {
2766 let table_scan = test_table_scan()?;
2767 let left = LogicalPlanBuilder::from(table_scan)
2768 .project(vec![col("a"), col("b"), col("c")])?
2769 .build()?;
2770 let right_table_scan = test_table_scan_with_name("test2")?;
2771 let right = LogicalPlanBuilder::from(right_table_scan)
2772 .project(vec![col("a"), col("b"), col("c")])?
2773 .build()?;
2774 let filter = col("test.b")
2775 .gt(lit(1u32))
2776 .and(col("test2.c").gt(lit(4u32)));
2777 let plan = LogicalPlanBuilder::from(left)
2778 .join(
2779 right,
2780 JoinType::Inner,
2781 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2782 Some(filter),
2783 )?
2784 .build()?;
2785
2786 assert_snapshot!(plan,
2788 @r"
2789 Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2790 Projection: test.a, test.b, test.c
2791 TableScan: test
2792 Projection: test2.a, test2.b, test2.c
2793 TableScan: test2
2794 ",
2795 );
2796 assert_optimized_plan_equal!(
2797 plan,
2798 @r"
2799 Inner Join: test.a = test2.a
2800 Projection: test.a, test.b, test.c
2801 TableScan: test, full_filters=[test.b > UInt32(1)]
2802 Projection: test2.a, test2.b, test2.c
2803 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2804 "
2805 )
2806 }
2807
2808 #[test]
2810 fn join_filter_on_common() -> Result<()> {
2811 let table_scan = test_table_scan()?;
2812 let left = LogicalPlanBuilder::from(table_scan)
2813 .project(vec![col("a")])?
2814 .build()?;
2815 let right_table_scan = test_table_scan_with_name("test2")?;
2816 let right = LogicalPlanBuilder::from(right_table_scan)
2817 .project(vec![col("b")])?
2818 .build()?;
2819 let filter = col("test.a").gt(lit(1u32));
2820 let plan = LogicalPlanBuilder::from(left)
2821 .join(
2822 right,
2823 JoinType::Inner,
2824 (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2825 Some(filter),
2826 )?
2827 .build()?;
2828
2829 assert_snapshot!(plan,
2831 @r"
2832 Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2833 Projection: test.a
2834 TableScan: test
2835 Projection: test2.b
2836 TableScan: test2
2837 ",
2838 );
2839 assert_optimized_plan_equal!(
2840 plan,
2841 @r"
2842 Inner Join: test.a = test2.b
2843 Projection: test.a
2844 TableScan: test, full_filters=[test.a > UInt32(1)]
2845 Projection: test2.b
2846 TableScan: test2, full_filters=[test2.b > UInt32(1)]
2847 "
2848 )
2849 }
2850
2851 #[test]
2853 fn left_join_on_with_filter() -> Result<()> {
2854 let table_scan = test_table_scan()?;
2855 let left = LogicalPlanBuilder::from(table_scan)
2856 .project(vec![col("a"), col("b"), col("c")])?
2857 .build()?;
2858 let right_table_scan = test_table_scan_with_name("test2")?;
2859 let right = LogicalPlanBuilder::from(right_table_scan)
2860 .project(vec![col("a"), col("b"), col("c")])?
2861 .build()?;
2862 let filter = col("test.a")
2863 .gt(lit(1u32))
2864 .and(col("test.b").lt(col("test2.b")))
2865 .and(col("test2.c").gt(lit(4u32)));
2866 let plan = LogicalPlanBuilder::from(left)
2867 .join(
2868 right,
2869 JoinType::Left,
2870 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2871 Some(filter),
2872 )?
2873 .build()?;
2874
2875 assert_snapshot!(plan,
2877 @r"
2878 Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2879 Projection: test.a, test.b, test.c
2880 TableScan: test
2881 Projection: test2.a, test2.b, test2.c
2882 TableScan: test2
2883 ",
2884 );
2885 assert_optimized_plan_equal!(
2886 plan,
2887 @r"
2888 Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2889 Projection: test.a, test.b, test.c
2890 TableScan: test
2891 Projection: test2.a, test2.b, test2.c
2892 TableScan: test2, full_filters=[test2.c > UInt32(4)]
2893 "
2894 )
2895 }
2896
2897 #[test]
2899 fn right_join_on_with_filter() -> Result<()> {
2900 let table_scan = test_table_scan()?;
2901 let left = LogicalPlanBuilder::from(table_scan)
2902 .project(vec![col("a"), col("b"), col("c")])?
2903 .build()?;
2904 let right_table_scan = test_table_scan_with_name("test2")?;
2905 let right = LogicalPlanBuilder::from(right_table_scan)
2906 .project(vec![col("a"), col("b"), col("c")])?
2907 .build()?;
2908 let filter = col("test.a")
2909 .gt(lit(1u32))
2910 .and(col("test.b").lt(col("test2.b")))
2911 .and(col("test2.c").gt(lit(4u32)));
2912 let plan = LogicalPlanBuilder::from(left)
2913 .join(
2914 right,
2915 JoinType::Right,
2916 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2917 Some(filter),
2918 )?
2919 .build()?;
2920
2921 assert_snapshot!(plan,
2923 @r"
2924 Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2925 Projection: test.a, test.b, test.c
2926 TableScan: test
2927 Projection: test2.a, test2.b, test2.c
2928 TableScan: test2
2929 ",
2930 );
2931 assert_optimized_plan_equal!(
2932 plan,
2933 @r"
2934 Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
2935 Projection: test.a, test.b, test.c
2936 TableScan: test, full_filters=[test.a > UInt32(1)]
2937 Projection: test2.a, test2.b, test2.c
2938 TableScan: test2
2939 "
2940 )
2941 }
2942
2943 #[test]
2945 fn full_join_on_with_filter() -> Result<()> {
2946 let table_scan = test_table_scan()?;
2947 let left = LogicalPlanBuilder::from(table_scan)
2948 .project(vec![col("a"), col("b"), col("c")])?
2949 .build()?;
2950 let right_table_scan = test_table_scan_with_name("test2")?;
2951 let right = LogicalPlanBuilder::from(right_table_scan)
2952 .project(vec![col("a"), col("b"), col("c")])?
2953 .build()?;
2954 let filter = col("test.a")
2955 .gt(lit(1u32))
2956 .and(col("test.b").lt(col("test2.b")))
2957 .and(col("test2.c").gt(lit(4u32)));
2958 let plan = LogicalPlanBuilder::from(left)
2959 .join(
2960 right,
2961 JoinType::Full,
2962 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2963 Some(filter),
2964 )?
2965 .build()?;
2966
2967 assert_snapshot!(plan,
2969 @r"
2970 Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2971 Projection: test.a, test.b, test.c
2972 TableScan: test
2973 Projection: test2.a, test2.b, test2.c
2974 TableScan: test2
2975 ",
2976 );
2977 assert_optimized_plan_equal!(
2978 plan,
2979 @r"
2980 Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2981 Projection: test.a, test.b, test.c
2982 TableScan: test
2983 Projection: test2.a, test2.b, test2.c
2984 TableScan: test2
2985 "
2986 )
2987 }
2988
2989 struct PushDownProvider {
2990 pub filter_support: TableProviderFilterPushDown,
2991 }
2992
2993 #[async_trait]
2994 impl TableSource for PushDownProvider {
2995 fn schema(&self) -> SchemaRef {
2996 Arc::new(Schema::new(vec![
2997 Field::new("a", DataType::Int32, true),
2998 Field::new("b", DataType::Int32, true),
2999 ]))
3000 }
3001
3002 fn table_type(&self) -> TableType {
3003 TableType::Base
3004 }
3005
3006 fn supports_filters_pushdown(
3007 &self,
3008 filters: &[&Expr],
3009 ) -> Result<Vec<TableProviderFilterPushDown>> {
3010 Ok((0..filters.len())
3011 .map(|_| self.filter_support.clone())
3012 .collect())
3013 }
3014
3015 fn as_any(&self) -> &dyn Any {
3016 self
3017 }
3018 }
3019
3020 fn table_scan_with_pushdown_provider_builder(
3021 filter_support: TableProviderFilterPushDown,
3022 filters: Vec<Expr>,
3023 projection: Option<Vec<usize>>,
3024 ) -> Result<LogicalPlanBuilder> {
3025 let test_provider = PushDownProvider { filter_support };
3026
3027 let table_scan = LogicalPlan::TableScan(TableScan {
3028 table_name: "test".into(),
3029 filters,
3030 projected_schema: Arc::new(DFSchema::try_from(
3031 (*test_provider.schema()).clone(),
3032 )?),
3033 projection,
3034 source: Arc::new(test_provider),
3035 fetch: None,
3036 });
3037
3038 Ok(LogicalPlanBuilder::from(table_scan))
3039 }
3040
3041 fn table_scan_with_pushdown_provider(
3042 filter_support: TableProviderFilterPushDown,
3043 ) -> Result<LogicalPlan> {
3044 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3045 .filter(col("a").eq(lit(1i64)))?
3046 .build()
3047 }
3048
3049 #[test]
3050 fn filter_with_table_provider_exact() -> Result<()> {
3051 let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3052
3053 assert_optimized_plan_equal!(
3054 plan,
3055 @"TableScan: test, full_filters=[a = Int64(1)]"
3056 )
3057 }
3058
3059 #[test]
3060 fn filter_with_table_provider_inexact() -> Result<()> {
3061 let plan =
3062 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3063
3064 assert_optimized_plan_equal!(
3065 plan,
3066 @r"
3067 Filter: a = Int64(1)
3068 TableScan: test, partial_filters=[a = Int64(1)]
3069 "
3070 )
3071 }
3072
3073 #[test]
3074 fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3075 let plan =
3076 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3077
3078 let optimized_plan = PushDownFilter::new()
3079 .rewrite(plan, &OptimizerContext::new())
3080 .expect("failed to optimize plan")
3081 .data;
3082
3083 assert_optimized_plan_equal!(
3086 optimized_plan,
3087 @r"
3088 Filter: a = Int64(1)
3089 TableScan: test, partial_filters=[a = Int64(1)]
3090 "
3091 )
3092 }
3093
3094 #[test]
3095 fn filter_with_table_provider_unsupported() -> Result<()> {
3096 let plan =
3097 table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3098
3099 assert_optimized_plan_equal!(
3100 plan,
3101 @r"
3102 Filter: a = Int64(1)
3103 TableScan: test
3104 "
3105 )
3106 }
3107
3108 #[test]
3109 fn multi_combined_filter() -> Result<()> {
3110 let plan = table_scan_with_pushdown_provider_builder(
3111 TableProviderFilterPushDown::Inexact,
3112 vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3113 Some(vec![0]),
3114 )?
3115 .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3116 .project(vec![col("a"), col("b")])?
3117 .build()?;
3118
3119 assert_optimized_plan_equal!(
3120 plan,
3121 @r"
3122 Projection: a, b
3123 Filter: a = Int64(10) AND b > Int64(11)
3124 TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3125 "
3126 )
3127 }
3128
3129 #[test]
3130 fn multi_combined_filter_exact() -> Result<()> {
3131 let plan = table_scan_with_pushdown_provider_builder(
3132 TableProviderFilterPushDown::Exact,
3133 vec![],
3134 Some(vec![0]),
3135 )?
3136 .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3137 .project(vec![col("a"), col("b")])?
3138 .build()?;
3139
3140 assert_optimized_plan_equal!(
3141 plan,
3142 @r"
3143 Projection: a, b
3144 TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3145 "
3146 )
3147 }
3148
3149 #[test]
3150 fn test_filter_with_alias() -> Result<()> {
3151 let table_scan = test_table_scan()?;
3155 let plan = LogicalPlanBuilder::from(table_scan)
3156 .project(vec![col("a").alias("b"), col("c")])?
3157 .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3158 .build()?;
3159
3160 assert_snapshot!(plan,
3162 @r"
3163 Filter: b > Int64(10) AND test.c > Int64(10)
3164 Projection: test.a AS b, test.c
3165 TableScan: test
3166 ",
3167 );
3168 assert_optimized_plan_equal!(
3170 plan,
3171 @r"
3172 Projection: test.a AS b, test.c
3173 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3174 "
3175 )
3176 }
3177
3178 #[test]
3179 fn test_filter_with_alias_2() -> Result<()> {
3180 let table_scan = test_table_scan()?;
3184 let plan = LogicalPlanBuilder::from(table_scan)
3185 .project(vec![col("a").alias("b"), col("c")])?
3186 .project(vec![col("b"), col("c")])?
3187 .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3188 .build()?;
3189
3190 assert_snapshot!(plan,
3192 @r"
3193 Filter: b > Int64(10) AND test.c > Int64(10)
3194 Projection: b, test.c
3195 Projection: test.a AS b, test.c
3196 TableScan: test
3197 ",
3198 );
3199 assert_optimized_plan_equal!(
3201 plan,
3202 @r"
3203 Projection: b, test.c
3204 Projection: test.a AS b, test.c
3205 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3206 "
3207 )
3208 }
3209
3210 #[test]
3211 fn test_filter_with_multi_alias() -> Result<()> {
3212 let table_scan = test_table_scan()?;
3213 let plan = LogicalPlanBuilder::from(table_scan)
3214 .project(vec![col("a").alias("b"), col("c").alias("d")])?
3215 .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3216 .build()?;
3217
3218 assert_snapshot!(plan,
3220 @r"
3221 Filter: b > Int64(10) AND d > Int64(10)
3222 Projection: test.a AS b, test.c AS d
3223 TableScan: test
3224 ",
3225 );
3226 assert_optimized_plan_equal!(
3228 plan,
3229 @r"
3230 Projection: test.a AS b, test.c AS d
3231 TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3232 "
3233 )
3234 }
3235
3236 #[test]
3238 fn join_filter_with_alias() -> Result<()> {
3239 let table_scan = test_table_scan()?;
3240 let left = LogicalPlanBuilder::from(table_scan)
3241 .project(vec![col("a").alias("c")])?
3242 .build()?;
3243 let right_table_scan = test_table_scan_with_name("test2")?;
3244 let right = LogicalPlanBuilder::from(right_table_scan)
3245 .project(vec![col("b").alias("d")])?
3246 .build()?;
3247 let filter = col("c").gt(lit(1u32));
3248 let plan = LogicalPlanBuilder::from(left)
3249 .join(
3250 right,
3251 JoinType::Inner,
3252 (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3253 Some(filter),
3254 )?
3255 .build()?;
3256
3257 assert_snapshot!(plan,
3258 @r"
3259 Inner Join: c = d Filter: c > UInt32(1)
3260 Projection: test.a AS c
3261 TableScan: test
3262 Projection: test2.b AS d
3263 TableScan: test2
3264 ",
3265 );
3266 assert_optimized_plan_equal!(
3268 plan,
3269 @r"
3270 Inner Join: c = d
3271 Projection: test.a AS c
3272 TableScan: test, full_filters=[test.a > UInt32(1)]
3273 Projection: test2.b AS d
3274 TableScan: test2, full_filters=[test2.b > UInt32(1)]
3275 "
3276 )
3277 }
3278
3279 #[test]
3280 fn test_in_filter_with_alias() -> Result<()> {
3281 let table_scan = test_table_scan()?;
3285 let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3286 let plan = LogicalPlanBuilder::from(table_scan)
3287 .project(vec![col("a").alias("b"), col("c")])?
3288 .filter(in_list(col("b"), filter_value, false))?
3289 .build()?;
3290
3291 assert_snapshot!(plan,
3293 @r"
3294 Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3295 Projection: test.a AS b, test.c
3296 TableScan: test
3297 ",
3298 );
3299 assert_optimized_plan_equal!(
3301 plan,
3302 @r"
3303 Projection: test.a AS b, test.c
3304 TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3305 "
3306 )
3307 }
3308
3309 #[test]
3310 fn test_in_filter_with_alias_2() -> Result<()> {
3311 let table_scan = test_table_scan()?;
3315 let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3316 let plan = LogicalPlanBuilder::from(table_scan)
3317 .project(vec![col("a").alias("b"), col("c")])?
3318 .project(vec![col("b"), col("c")])?
3319 .filter(in_list(col("b"), filter_value, false))?
3320 .build()?;
3321
3322 assert_snapshot!(plan,
3324 @r"
3325 Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3326 Projection: b, test.c
3327 Projection: test.a AS b, test.c
3328 TableScan: test
3329 ",
3330 );
3331 assert_optimized_plan_equal!(
3333 plan,
3334 @r"
3335 Projection: b, test.c
3336 Projection: test.a AS b, test.c
3337 TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3338 "
3339 )
3340 }
3341
3342 #[test]
3343 fn test_in_subquery_with_alias() -> Result<()> {
3344 let table_scan = test_table_scan()?;
3347 let table_scan_sq = test_table_scan_with_name("sq")?;
3348 let subplan = Arc::new(
3349 LogicalPlanBuilder::from(table_scan_sq)
3350 .project(vec![col("c")])?
3351 .build()?,
3352 );
3353 let plan = LogicalPlanBuilder::from(table_scan)
3354 .project(vec![col("a").alias("b"), col("c")])?
3355 .filter(in_subquery(col("b"), subplan))?
3356 .build()?;
3357
3358 assert_snapshot!(plan,
3360 @r"
3361 Filter: b IN (<subquery>)
3362 Subquery:
3363 Projection: sq.c
3364 TableScan: sq
3365 Projection: test.a AS b, test.c
3366 TableScan: test
3367 ",
3368 );
3369 assert_optimized_plan_equal!(
3371 plan,
3372 @r"
3373 Projection: test.a AS b, test.c
3374 TableScan: test, full_filters=[test.a IN (<subquery>)]
3375 Subquery:
3376 Projection: sq.c
3377 TableScan: sq
3378 "
3379 )
3380 }
3381
3382 #[test]
3383 fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3384 let plan = LogicalPlanBuilder::empty(true)
3386 .project(vec![lit(0i64).alias("a")])?
3387 .alias("b")?
3388 .project(vec![col("b.a")])?
3389 .alias("b")?
3390 .filter(col("b.a").eq(lit(1i64)))?
3391 .project(vec![col("b.a")])?
3392 .build()?;
3393
3394 assert_snapshot!(plan,
3395 @r"
3396 Projection: b.a
3397 Filter: b.a = Int64(1)
3398 SubqueryAlias: b
3399 Projection: b.a
3400 SubqueryAlias: b
3401 Projection: Int64(0) AS a
3402 EmptyRelation
3403 ",
3404 );
3405 assert_optimized_plan_equal!(
3408 plan,
3409 @r"
3410 Projection: b.a
3411 SubqueryAlias: b
3412 Projection: b.a
3413 SubqueryAlias: b
3414 Projection: Int64(0) AS a
3415 Filter: Int64(0) = Int64(1)
3416 EmptyRelation
3417 "
3418 )
3419 }
3420
3421 #[test]
3422 fn test_crossjoin_with_or_clause() -> Result<()> {
3423 let table_scan = test_table_scan()?;
3425 let left = LogicalPlanBuilder::from(table_scan)
3426 .project(vec![col("a"), col("b"), col("c")])?
3427 .build()?;
3428 let right_table_scan = test_table_scan_with_name("test1")?;
3429 let right = LogicalPlanBuilder::from(right_table_scan)
3430 .project(vec![col("a").alias("d"), col("a").alias("e")])?
3431 .build()?;
3432 let filter = or(
3433 and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3434 and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3435 );
3436 let plan = LogicalPlanBuilder::from(left)
3437 .cross_join(right)?
3438 .filter(filter)?
3439 .build()?;
3440
3441 assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3442 Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3443 Projection: test.a, test.b, test.c
3444 TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3445 Projection: test1.a AS d, test1.a AS e
3446 TableScan: test1
3447 ")?;
3448
3449 let optimized_plan = PushDownFilter::new()
3452 .rewrite(plan, &OptimizerContext::new())
3453 .expect("failed to optimize plan")
3454 .data;
3455 assert_optimized_plan_equal!(
3456 optimized_plan,
3457 @r"
3458 Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3459 Projection: test.a, test.b, test.c
3460 TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3461 Projection: test1.a AS d, test1.a AS e
3462 TableScan: test1
3463 "
3464 )
3465 }
3466
3467 #[test]
3468 fn left_semi_join() -> Result<()> {
3469 let left = test_table_scan_with_name("test1")?;
3470 let right_table_scan = test_table_scan_with_name("test2")?;
3471 let right = LogicalPlanBuilder::from(right_table_scan)
3472 .project(vec![col("a"), col("b")])?
3473 .build()?;
3474 let plan = LogicalPlanBuilder::from(left)
3475 .join(
3476 right,
3477 JoinType::LeftSemi,
3478 (
3479 vec![Column::from_qualified_name("test1.a")],
3480 vec![Column::from_qualified_name("test2.a")],
3481 ),
3482 None,
3483 )?
3484 .filter(col("test2.a").lt_eq(lit(1i64)))?
3485 .build()?;
3486
3487 assert_snapshot!(plan,
3489 @r"
3490 Filter: test2.a <= Int64(1)
3491 LeftSemi Join: test1.a = test2.a
3492 TableScan: test1
3493 Projection: test2.a, test2.b
3494 TableScan: test2
3495 ",
3496 );
3497 assert_optimized_plan_equal!(
3499 plan,
3500 @r"
3501 Filter: test2.a <= Int64(1)
3502 LeftSemi Join: test1.a = test2.a
3503 TableScan: test1, full_filters=[test1.a <= Int64(1)]
3504 Projection: test2.a, test2.b
3505 TableScan: test2
3506 "
3507 )
3508 }
3509
3510 #[test]
3511 fn left_semi_join_with_filters() -> Result<()> {
3512 let left = test_table_scan_with_name("test1")?;
3513 let right_table_scan = test_table_scan_with_name("test2")?;
3514 let right = LogicalPlanBuilder::from(right_table_scan)
3515 .project(vec![col("a"), col("b")])?
3516 .build()?;
3517 let plan = LogicalPlanBuilder::from(left)
3518 .join(
3519 right,
3520 JoinType::LeftSemi,
3521 (
3522 vec![Column::from_qualified_name("test1.a")],
3523 vec![Column::from_qualified_name("test2.a")],
3524 ),
3525 Some(
3526 col("test1.b")
3527 .gt(lit(1u32))
3528 .and(col("test2.b").gt(lit(2u32))),
3529 ),
3530 )?
3531 .build()?;
3532
3533 assert_snapshot!(plan,
3535 @r"
3536 LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3537 TableScan: test1
3538 Projection: test2.a, test2.b
3539 TableScan: test2
3540 ",
3541 );
3542 assert_optimized_plan_equal!(
3544 plan,
3545 @r"
3546 LeftSemi Join: test1.a = test2.a
3547 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3548 Projection: test2.a, test2.b
3549 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3550 "
3551 )
3552 }
3553
3554 #[test]
3555 fn right_semi_join() -> Result<()> {
3556 let left = test_table_scan_with_name("test1")?;
3557 let right_table_scan = test_table_scan_with_name("test2")?;
3558 let right = LogicalPlanBuilder::from(right_table_scan)
3559 .project(vec![col("a"), col("b")])?
3560 .build()?;
3561 let plan = LogicalPlanBuilder::from(left)
3562 .join(
3563 right,
3564 JoinType::RightSemi,
3565 (
3566 vec![Column::from_qualified_name("test1.a")],
3567 vec![Column::from_qualified_name("test2.a")],
3568 ),
3569 None,
3570 )?
3571 .filter(col("test1.a").lt_eq(lit(1i64)))?
3572 .build()?;
3573
3574 assert_snapshot!(plan,
3576 @r"
3577 Filter: test1.a <= Int64(1)
3578 RightSemi Join: test1.a = test2.a
3579 TableScan: test1
3580 Projection: test2.a, test2.b
3581 TableScan: test2
3582 ",
3583 );
3584 assert_optimized_plan_equal!(
3586 plan,
3587 @r"
3588 Filter: test1.a <= Int64(1)
3589 RightSemi Join: test1.a = test2.a
3590 TableScan: test1
3591 Projection: test2.a, test2.b
3592 TableScan: test2, full_filters=[test2.a <= Int64(1)]
3593 "
3594 )
3595 }
3596
3597 #[test]
3598 fn right_semi_join_with_filters() -> Result<()> {
3599 let left = test_table_scan_with_name("test1")?;
3600 let right_table_scan = test_table_scan_with_name("test2")?;
3601 let right = LogicalPlanBuilder::from(right_table_scan)
3602 .project(vec![col("a"), col("b")])?
3603 .build()?;
3604 let plan = LogicalPlanBuilder::from(left)
3605 .join(
3606 right,
3607 JoinType::RightSemi,
3608 (
3609 vec![Column::from_qualified_name("test1.a")],
3610 vec![Column::from_qualified_name("test2.a")],
3611 ),
3612 Some(
3613 col("test1.b")
3614 .gt(lit(1u32))
3615 .and(col("test2.b").gt(lit(2u32))),
3616 ),
3617 )?
3618 .build()?;
3619
3620 assert_snapshot!(plan,
3622 @r"
3623 RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3624 TableScan: test1
3625 Projection: test2.a, test2.b
3626 TableScan: test2
3627 ",
3628 );
3629 assert_optimized_plan_equal!(
3631 plan,
3632 @r"
3633 RightSemi Join: test1.a = test2.a
3634 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3635 Projection: test2.a, test2.b
3636 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3637 "
3638 )
3639 }
3640
3641 #[test]
3642 fn left_anti_join() -> Result<()> {
3643 let table_scan = test_table_scan_with_name("test1")?;
3644 let left = LogicalPlanBuilder::from(table_scan)
3645 .project(vec![col("a"), col("b")])?
3646 .build()?;
3647 let right_table_scan = test_table_scan_with_name("test2")?;
3648 let right = LogicalPlanBuilder::from(right_table_scan)
3649 .project(vec![col("a"), col("b")])?
3650 .build()?;
3651 let plan = LogicalPlanBuilder::from(left)
3652 .join(
3653 right,
3654 JoinType::LeftAnti,
3655 (
3656 vec![Column::from_qualified_name("test1.a")],
3657 vec![Column::from_qualified_name("test2.a")],
3658 ),
3659 None,
3660 )?
3661 .filter(col("test2.a").gt(lit(2u32)))?
3662 .build()?;
3663
3664 assert_snapshot!(plan,
3666 @r"
3667 Filter: test2.a > UInt32(2)
3668 LeftAnti Join: test1.a = test2.a
3669 Projection: test1.a, test1.b
3670 TableScan: test1
3671 Projection: test2.a, test2.b
3672 TableScan: test2
3673 ",
3674 );
3675 assert_optimized_plan_equal!(
3677 plan,
3678 @r"
3679 Filter: test2.a > UInt32(2)
3680 LeftAnti Join: test1.a = test2.a
3681 Projection: test1.a, test1.b
3682 TableScan: test1, full_filters=[test1.a > UInt32(2)]
3683 Projection: test2.a, test2.b
3684 TableScan: test2
3685 "
3686 )
3687 }
3688
3689 #[test]
3690 fn left_anti_join_with_filters() -> Result<()> {
3691 let table_scan = test_table_scan_with_name("test1")?;
3692 let left = LogicalPlanBuilder::from(table_scan)
3693 .project(vec![col("a"), col("b")])?
3694 .build()?;
3695 let right_table_scan = test_table_scan_with_name("test2")?;
3696 let right = LogicalPlanBuilder::from(right_table_scan)
3697 .project(vec![col("a"), col("b")])?
3698 .build()?;
3699 let plan = LogicalPlanBuilder::from(left)
3700 .join(
3701 right,
3702 JoinType::LeftAnti,
3703 (
3704 vec![Column::from_qualified_name("test1.a")],
3705 vec![Column::from_qualified_name("test2.a")],
3706 ),
3707 Some(
3708 col("test1.b")
3709 .gt(lit(1u32))
3710 .and(col("test2.b").gt(lit(2u32))),
3711 ),
3712 )?
3713 .build()?;
3714
3715 assert_snapshot!(plan,
3717 @r"
3718 LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3719 Projection: test1.a, test1.b
3720 TableScan: test1
3721 Projection: test2.a, test2.b
3722 TableScan: test2
3723 ",
3724 );
3725 assert_optimized_plan_equal!(
3727 plan,
3728 @r"
3729 LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3730 Projection: test1.a, test1.b
3731 TableScan: test1
3732 Projection: test2.a, test2.b
3733 TableScan: test2, full_filters=[test2.b > UInt32(2)]
3734 "
3735 )
3736 }
3737
3738 #[test]
3739 fn right_anti_join() -> Result<()> {
3740 let table_scan = test_table_scan_with_name("test1")?;
3741 let left = LogicalPlanBuilder::from(table_scan)
3742 .project(vec![col("a"), col("b")])?
3743 .build()?;
3744 let right_table_scan = test_table_scan_with_name("test2")?;
3745 let right = LogicalPlanBuilder::from(right_table_scan)
3746 .project(vec![col("a"), col("b")])?
3747 .build()?;
3748 let plan = LogicalPlanBuilder::from(left)
3749 .join(
3750 right,
3751 JoinType::RightAnti,
3752 (
3753 vec![Column::from_qualified_name("test1.a")],
3754 vec![Column::from_qualified_name("test2.a")],
3755 ),
3756 None,
3757 )?
3758 .filter(col("test1.a").gt(lit(2u32)))?
3759 .build()?;
3760
3761 assert_snapshot!(plan,
3763 @r"
3764 Filter: test1.a > UInt32(2)
3765 RightAnti Join: test1.a = test2.a
3766 Projection: test1.a, test1.b
3767 TableScan: test1
3768 Projection: test2.a, test2.b
3769 TableScan: test2
3770 ",
3771 );
3772 assert_optimized_plan_equal!(
3774 plan,
3775 @r"
3776 Filter: test1.a > UInt32(2)
3777 RightAnti Join: test1.a = test2.a
3778 Projection: test1.a, test1.b
3779 TableScan: test1
3780 Projection: test2.a, test2.b
3781 TableScan: test2, full_filters=[test2.a > UInt32(2)]
3782 "
3783 )
3784 }
3785
3786 #[test]
3787 fn right_anti_join_with_filters() -> Result<()> {
3788 let table_scan = test_table_scan_with_name("test1")?;
3789 let left = LogicalPlanBuilder::from(table_scan)
3790 .project(vec![col("a"), col("b")])?
3791 .build()?;
3792 let right_table_scan = test_table_scan_with_name("test2")?;
3793 let right = LogicalPlanBuilder::from(right_table_scan)
3794 .project(vec![col("a"), col("b")])?
3795 .build()?;
3796 let plan = LogicalPlanBuilder::from(left)
3797 .join(
3798 right,
3799 JoinType::RightAnti,
3800 (
3801 vec![Column::from_qualified_name("test1.a")],
3802 vec![Column::from_qualified_name("test2.a")],
3803 ),
3804 Some(
3805 col("test1.b")
3806 .gt(lit(1u32))
3807 .and(col("test2.b").gt(lit(2u32))),
3808 ),
3809 )?
3810 .build()?;
3811
3812 assert_snapshot!(plan,
3814 @r"
3815 RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3816 Projection: test1.a, test1.b
3817 TableScan: test1
3818 Projection: test2.a, test2.b
3819 TableScan: test2
3820 ",
3821 );
3822 assert_optimized_plan_equal!(
3824 plan,
3825 @r"
3826 RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3827 Projection: test1.a, test1.b
3828 TableScan: test1, full_filters=[test1.b > UInt32(1)]
3829 Projection: test2.a, test2.b
3830 TableScan: test2
3831 "
3832 )
3833 }
3834
3835 #[derive(Debug)]
3836 struct TestScalarUDF {
3837 signature: Signature,
3838 }
3839
3840 impl ScalarUDFImpl for TestScalarUDF {
3841 fn as_any(&self) -> &dyn Any {
3842 self
3843 }
3844 fn name(&self) -> &str {
3845 "TestScalarUDF"
3846 }
3847
3848 fn signature(&self) -> &Signature {
3849 &self.signature
3850 }
3851
3852 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3853 Ok(DataType::Int32)
3854 }
3855
3856 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3857 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3858 }
3859 }
3860
3861 #[test]
3862 fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3863 let table_scan = test_table_scan_with_name("test1")?;
3865 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3866 signature: Signature::exact(vec![], Volatility::Volatile),
3867 });
3868 let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3869
3870 let plan = LogicalPlanBuilder::from(table_scan)
3871 .aggregate(vec![col("a")], vec![sum(col("b"))])?
3872 .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3873 .alias("t")?
3874 .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3875 .project(vec![col("t.a"), col("t.r")])?
3876 .build()?;
3877
3878 assert_snapshot!(plan,
3879 @r"
3880 Projection: t.a, t.r
3881 Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3882 SubqueryAlias: t
3883 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3884 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3885 TableScan: test1
3886 ",
3887 );
3888 assert_optimized_plan_equal!(
3889 plan,
3890 @r"
3891 Projection: t.a, t.r
3892 SubqueryAlias: t
3893 Filter: r > Float64(0.5)
3894 Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3895 Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3896 TableScan: test1, full_filters=[test1.a > Int32(5)]
3897 "
3898 )
3899 }
3900
#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
    let test1_scan = test_table_scan_with_name("test1")?;

    // A zero-argument UDF whose signature is declared volatile; the projection
    // below exposes its result as column `r`.
    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    // Both join inputs are bare table scans, joined on their `a` columns.
    let left_input = LogicalPlanBuilder::from(test1_scan).build()?;
    let test2_scan = test_table_scan_with_name("test2")?;
    let right_input = LogicalPlanBuilder::from(test2_scan).build()?;
    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );

    // The filter refers to the volatile column through the `t` alias.
    let r_predicate = col("t.r").gt(lit(0.8));
    let plan = LogicalPlanBuilder::from(left_input)
        .join(right_input, JoinType::Inner, join_keys, None)?
        .project(vec![col("test1.a").alias("a"), volatile_call.alias("r")])?
        .alias("t")?
        .filter(r_predicate)?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.r > Float64(0.8)
        SubqueryAlias: t
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    ",
    );

    // Expected result: the filter moves below the alias but stays above the
    // projection that evaluates the volatile function; nothing reaches the join.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.8)
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    "
    )
}
3952
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    let scan = test_table_scan()?;

    // The only predicate is a comparison on a volatile, zero-argument UDF call.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));
    let predicate = volatile_call.gt(lit(0.1));

    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1)
      Projection: test.a, test.b
        TableScan: test
    ",
    );

    // Expected result: the filter sinks below the projection but remains a
    // separate Filter node instead of being attached to the table scan.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
3982
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    let scan = test_table_scan()?;

    // One volatile conjunct plus two plain column comparisons.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_conjunct =
        Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![])).gt(lit(0.1));
    let a_conjunct = col("t.a").gt(lit(5));
    let b_conjunct = col("t.b").gt(lit(10));

    // Combined left to right: (volatile AND a) AND b.
    let predicate = volatile_conjunct.and(a_conjunct).and(b_conjunct);

    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: test.a, test.b
        TableScan: test
    ",
    );

    // Expected result: the two column comparisons end up in the scan's
    // `full_filters`, while the volatile comparison stays behind in a Filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
    "
    )
}
4016
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    // One volatile conjunct plus two plain column comparisons.
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_conjunct =
        Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![])).gt(lit(0.1));
    let a_conjunct = col("t.a").gt(lit(5));
    let b_conjunct = col("t.b").gt(lit(10));

    // Combined left to right: (volatile AND a) AND b.
    let predicate = volatile_conjunct.and(a_conjunct).and(b_conjunct);

    // The scan is built with a provider configured as `Unsupported` for filter
    // pushdown, with no pre-existing filters and no projection.
    let scan_builder = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?;
    let plan = scan_builder
        .project(vec![col("a"), col("b")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
      Projection: a, b
        TableScan: test
    ",
    );

    // Expected result: all three conjuncts stay in a single Filter directly
    // above the scan, with the volatile comparison listed last.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4053
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    // Minimal leaf extension node: one schema, no inputs, no expressions.
    #[derive(Debug, Hash, Eq, PartialEq)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    // No ordering is defined between nodes; every comparison yields `None`.
    impl PartialOrd for TestUserNode {
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl TestUserNode {
        // Builds a node whose schema is a single unqualified, non-nullable
        // Int64 column named `a`.
        fn new() -> Self {
            let only_field = Field::new("a", DataType::Int64, false);
            let df_schema = DFSchema::new_with_metadata(
                vec![(None, only_field.into())],
                Default::default(),
            )
            .unwrap();

            Self {
                schema: Arc::new(df_schema),
            }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            f.write_str("TestUserNode")
        }

        // The node is a leaf with no expressions, so a rebuild must be handed
        // empty vectors; the schema handle is shared with the rebuilt node.
        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            let schema = Arc::clone(&self.schema);
            Ok(Self { schema })
        }
    }

    // Wrap the custom node in an Extension plan and put a constant-false
    // filter directly on top of it.
    let extension_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(TestUserNode::new()),
    });
    let plan = LogicalPlanBuilder::from(extension_plan)
        .filter(lit(false))?
        .build()?;

    // Plan as built, before the rule runs.
    assert_snapshot!(plan,
    @r"
    Filter: Boolean(false)
      TestUserNode
    ",
    );

    // Expected result: the plan is unchanged, with the filter still sitting
    // above the extension node.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: Boolean(false)
      TestUserNode
    "
    )
}
4139}