1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use datafusion::logical_expr::logical_plan::LogicalPlan;
24use datafusion::logical_expr::{
25 BinaryExpr, Expr, Extension, Filter, Operator as DfOperator, UserDefinedLogicalNodeCore,
26};
27use datafusion_common::tree_node::Transformed;
28use datafusion_common::Result;
29use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
30
31use crate::datafusion::lookup_join::{LookupJoinNode, LookupJoinType};
32
/// Which side(s) of a lookup join a predicate expression references.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredicateClass {
    /// Only lookup-table columns are referenced; the only class ever considered for
    /// pushdown to the lookup source.
    LookupOnly,
    /// Only stream (join input) columns are referenced.
    StreamOnly,
    /// Columns from both sides are referenced, or the expression contains a construct the
    /// classifier does not inspect (conservatively treated as touching both sides).
    CrossReference,
    /// No recognized column is referenced (literals, placeholders, unknown column names).
    Constant,
}
49
50#[derive(Debug)]
56pub struct PredicateClassifier {
57 lookup_columns: HashSet<String>,
59 stream_columns: HashSet<String>,
61 lookup_qualified: HashSet<String>,
63 stream_qualified: HashSet<String>,
65}
66
67impl PredicateClassifier {
68 #[must_use]
74 pub fn new(
75 lookup_columns: HashSet<String>,
76 stream_columns: HashSet<String>,
77 lookup_alias: Option<&str>,
78 stream_alias: Option<&str>,
79 ) -> Self {
80 let mut lookup_qualified = HashSet::new();
81 let mut stream_qualified = HashSet::new();
82
83 if let Some(alias) = lookup_alias {
84 for col in &lookup_columns {
85 lookup_qualified.insert(format!("{alias}.{col}"));
86 }
87 }
88 if let Some(alias) = stream_alias {
89 for col in &stream_columns {
90 stream_qualified.insert(format!("{alias}.{col}"));
91 }
92 }
93
94 Self {
95 lookup_columns,
96 stream_columns,
97 lookup_qualified,
98 stream_qualified,
99 }
100 }
101
102 #[must_use]
104 pub fn classify(&self, expr: &Expr) -> PredicateClass {
105 let mut has_lookup = false;
106 let mut has_stream = false;
107 self.walk_columns(expr, &mut has_lookup, &mut has_stream);
108
109 match (has_lookup, has_stream) {
110 (true, false) => PredicateClass::LookupOnly,
111 (false, true) => PredicateClass::StreamOnly,
112 (true, true) => PredicateClass::CrossReference,
113 (false, false) => PredicateClass::Constant,
114 }
115 }
116
117 fn walk_columns(&self, expr: &Expr, has_lookup: &mut bool, has_stream: &mut bool) {
119 match expr {
120 Expr::Column(col) => {
121 if let Some(relation) = &col.relation {
123 let qualified = format!("{}.{}", relation, col.name);
124 if self.lookup_qualified.contains(&qualified) {
125 *has_lookup = true;
126 return;
127 }
128 if self.stream_qualified.contains(&qualified) {
129 *has_stream = true;
130 return;
131 }
132 }
133 if self.lookup_columns.contains(&col.name) {
135 *has_lookup = true;
136 }
137 if self.stream_columns.contains(&col.name) {
138 *has_stream = true;
139 }
140 }
141 Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
142 self.walk_columns(left, has_lookup, has_stream);
143 self.walk_columns(right, has_lookup, has_stream);
144 }
145 Expr::Not(inner)
146 | Expr::IsNull(inner)
147 | Expr::IsNotNull(inner)
148 | Expr::Negative(inner)
149 | Expr::Cast(datafusion::logical_expr::Cast { expr: inner, .. })
150 | Expr::TryCast(datafusion::logical_expr::TryCast { expr: inner, .. }) => {
151 self.walk_columns(inner, has_lookup, has_stream);
152 }
153 Expr::Between(between) => {
154 self.walk_columns(&between.expr, has_lookup, has_stream);
155 self.walk_columns(&between.low, has_lookup, has_stream);
156 self.walk_columns(&between.high, has_lookup, has_stream);
157 }
158 Expr::InList(in_list) => {
159 self.walk_columns(&in_list.expr, has_lookup, has_stream);
160 for item in &in_list.list {
161 self.walk_columns(item, has_lookup, has_stream);
162 }
163 }
164 Expr::ScalarFunction(func) => {
165 for arg in &func.args {
166 self.walk_columns(arg, has_lookup, has_stream);
167 }
168 }
169 Expr::Like(like) => {
170 self.walk_columns(&like.expr, has_lookup, has_stream);
171 self.walk_columns(&like.pattern, has_lookup, has_stream);
172 }
173 Expr::Case(case) => {
174 if let Some(operand) = &case.expr {
175 self.walk_columns(operand, has_lookup, has_stream);
176 }
177 for (when, then) in &case.when_then_expr {
178 self.walk_columns(when, has_lookup, has_stream);
179 self.walk_columns(then, has_lookup, has_stream);
180 }
181 if let Some(else_expr) = &case.else_expr {
182 self.walk_columns(else_expr, has_lookup, has_stream);
183 }
184 }
185 Expr::Literal(..) | Expr::Placeholder(_) => {}
187 _ => {
189 *has_lookup = true;
190 *has_stream = true;
191 }
192 }
193 }
194}
195
/// How much predicate pushdown a lookup source accepts, as seen by the planner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlanPushdownMode {
    /// Lookup-only predicates may be pushed to the source.
    Full,
    /// NOTE(review): the name suggests key-column-only pushdown, but the predicate splitter
    /// only distinguishes `None` from everything else, so this currently behaves like
    /// `Full` — confirm the intended semantics.
    KeyOnly,
    /// Pushdown disabled; every predicate is evaluated locally.
    None,
}

/// Planner-side description of what a lookup source can evaluate on its own.
#[derive(Debug, Clone)]
pub struct PlanSourceCapabilities {
    /// Overall pushdown mode of the source.
    pub pushdown_mode: PlanPushdownMode,
    /// Columns on which the source can evaluate equality predicates.
    pub eq_columns: HashSet<String>,
    /// Columns on which the source can evaluate range predicates.
    pub range_columns: HashSet<String>,
    /// Columns on which the source can evaluate `IN` lists.
    pub in_columns: HashSet<String>,
    /// Whether the source can evaluate `IS [NOT] NULL` checks.
    pub supports_null_check: bool,
}

impl Default for PlanSourceCapabilities {
    /// No pushdown, no per-column capabilities, no null-check support.
    fn default() -> Self {
        Self {
            pushdown_mode: PlanPushdownMode::None,
            eq_columns: HashSet::default(),
            range_columns: HashSet::default(),
            in_columns: HashSet::default(),
            supports_null_check: false,
        }
    }
}

/// Map from lookup-table name to that table's [`PlanSourceCapabilities`].
#[derive(Debug, Default)]
pub struct SourceCapabilitiesRegistry {
    capabilities: HashMap<String, PlanSourceCapabilities>,
}

impl SourceCapabilitiesRegistry {
    /// Stores `caps` for `table_name`, replacing any entry previously registered under it.
    pub fn register(&mut self, table_name: String, caps: PlanSourceCapabilities) {
        // Any replaced entry is intentionally discarded.
        let _replaced = self.capabilities.insert(table_name, caps);
    }

    /// Returns the capabilities registered for `table_name`, if any.
    #[must_use]
    pub fn get(&self, table_name: &str) -> Option<&PlanSourceCapabilities> {
        self.capabilities.get(table_name)
    }
}
256
257#[must_use]
266pub fn split_conjunction(expr: &Expr) -> Vec<Expr> {
267 match expr {
268 Expr::BinaryExpr(BinaryExpr {
269 left,
270 op: DfOperator::And,
271 right,
272 }) => {
273 let mut parts = split_conjunction(left);
274 parts.extend(split_conjunction(right));
275 parts
276 }
277 other => vec![other.clone()],
278 }
279}
280
281#[derive(Debug)]
295pub struct PredicateSplitterRule {
296 capabilities: SourceCapabilitiesRegistry,
298}
299
300impl PredicateSplitterRule {
301 #[must_use]
303 pub fn new(capabilities: SourceCapabilitiesRegistry) -> Self {
304 Self { capabilities }
305 }
306
307 fn split_for_node(
311 &self,
312 node: &LookupJoinNode,
313 filter_predicates: &[Expr],
314 ) -> (Vec<Expr>, Vec<Expr>) {
315 let lookup_columns: HashSet<String> = node
317 .lookup_schema()
318 .fields()
319 .iter()
320 .map(|f| f.name().clone())
321 .collect();
322
323 let input_schema = node.inputs()[0].schema();
324 let stream_columns: HashSet<String> = input_schema
325 .fields()
326 .iter()
327 .map(|f| f.name().clone())
328 .collect();
329
330 let classifier = PredicateClassifier::new(
331 lookup_columns,
332 stream_columns,
333 node.lookup_alias(),
334 node.stream_alias(),
335 );
336
337 let caps = self.capabilities.get(node.lookup_table_name());
338 let pushdown_disabled = caps.is_none_or(|c| c.pushdown_mode == PlanPushdownMode::None);
339
340 let is_left_outer = node.join_type() == LookupJoinType::LeftOuter;
341
342 let mut pushdown = Vec::new();
343 let mut local = Vec::new();
344
345 let all_predicates = node
347 .pushdown_predicates()
348 .iter()
349 .chain(node.local_predicates().iter())
350 .chain(filter_predicates.iter())
351 .cloned();
352
353 for pred in all_predicates {
354 let class = classifier.classify(&pred);
355
356 let has_not_eq = contains_not_eq(&pred);
358
359 match class {
360 PredicateClass::LookupOnly => {
361 if is_left_outer || pushdown_disabled || has_not_eq {
363 local.push(pred);
364 } else {
365 pushdown.push(pred);
366 }
367 }
368 PredicateClass::StreamOnly
369 | PredicateClass::CrossReference
370 | PredicateClass::Constant => {
371 local.push(pred);
372 }
373 }
374 }
375
376 (pushdown, local)
377 }
378}
379
380fn contains_not_eq(expr: &Expr) -> bool {
382 match expr {
383 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
384 *op == DfOperator::NotEq || contains_not_eq(left) || contains_not_eq(right)
385 }
386 Expr::Not(inner) => contains_not_eq(inner),
387 _ => false,
388 }
389}
390
391impl OptimizerRule for PredicateSplitterRule {
392 fn name(&self) -> &'static str {
393 "predicate_splitter"
394 }
395
396 fn apply_order(&self) -> Option<ApplyOrder> {
397 Some(ApplyOrder::TopDown)
398 }
399
400 fn rewrite(
401 &self,
402 plan: LogicalPlan,
403 _config: &dyn OptimizerConfig,
404 ) -> Result<Transformed<LogicalPlan>> {
405 if let LogicalPlan::Filter(Filter {
407 predicate, input, ..
408 }) = &plan
409 {
410 if let LogicalPlan::Extension(ext) = input.as_ref() {
411 if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
412 let filter_preds = split_conjunction(predicate);
413 let (pushdown, local) = self.split_for_node(node, &filter_preds);
414
415 let inputs = node.inputs();
416 let rebuilt = LookupJoinNode::new(
417 inputs[0].clone(),
418 node.lookup_table_name().to_string(),
419 node.lookup_schema().clone(),
420 node.join_keys().to_vec(),
421 node.join_type(),
422 pushdown,
423 node.required_lookup_columns().clone(),
424 UserDefinedLogicalNodeCore::schema(node).clone(),
425 node.metadata().clone(),
426 )
427 .with_local_predicates(local)
428 .with_aliases(
429 node.lookup_alias().map(String::from),
430 node.stream_alias().map(String::from),
431 );
432
433 return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
434 node: Arc::new(rebuilt),
435 })));
436 }
437 }
438 }
439
440 if let LogicalPlan::Extension(ext) = &plan {
442 if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
443 if !node.pushdown_predicates().is_empty() || !node.local_predicates().is_empty() {
445 let (pushdown, local) = self.split_for_node(node, &[]);
446 let inputs = node.inputs();
447 let rebuilt = LookupJoinNode::new(
448 inputs[0].clone(),
449 node.lookup_table_name().to_string(),
450 node.lookup_schema().clone(),
451 node.join_keys().to_vec(),
452 node.join_type(),
453 pushdown,
454 node.required_lookup_columns().clone(),
455 UserDefinedLogicalNodeCore::schema(node).clone(),
456 node.metadata().clone(),
457 )
458 .with_local_predicates(local)
459 .with_aliases(
460 node.lookup_alias().map(String::from),
461 node.stream_alias().map(String::from),
462 );
463
464 return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
465 node: Arc::new(rebuilt),
466 })));
467 }
468 }
469 }
470
471 Ok(Transformed::no(plan))
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::col;
    use datafusion::prelude::lit;

    use crate::datafusion::lookup_join::{
        JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
    };

    // ---- classifier fixtures ---------------------------------------------------------

    /// Collects string literals into an owned `HashSet<String>`.
    fn string_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    fn lookup_cols() -> HashSet<String> {
        string_set(&["id", "name", "region"])
    }

    fn stream_cols() -> HashSet<String> {
        string_set(&["order_id", "customer_id", "amount"])
    }

    fn classifier() -> PredicateClassifier {
        PredicateClassifier::new(lookup_cols(), stream_cols(), None, None)
    }

    /// Classifier with the lookup side aliased `c` and the stream side aliased `o`.
    fn classifier_with_aliases() -> PredicateClassifier {
        PredicateClassifier::new(lookup_cols(), stream_cols(), Some("c"), Some("o"))
    }

    /// Column reference qualified with `relation`.
    fn qualified_col(relation: &str, name: &str) -> Expr {
        Expr::Column(datafusion::common::Column::new(Some(relation), name))
    }

    // ---- classification ----------------------------------------------------------------

    #[test]
    fn test_classify_lookup_only() {
        let expr = col("region").eq(lit("US"));
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    #[test]
    fn test_classify_stream_only() {
        let expr = col("amount").gt(lit(100));
        assert_eq!(classifier().classify(&expr), PredicateClass::StreamOnly);
    }

    #[test]
    fn test_classify_cross_reference() {
        let expr = col("amount").gt(col("id"));
        assert_eq!(classifier().classify(&expr), PredicateClass::CrossReference);
    }

    #[test]
    fn test_classify_constant() {
        let expr = lit(1).eq(lit(1));
        assert_eq!(classifier().classify(&expr), PredicateClass::Constant);
    }

    #[test]
    fn test_classify_qualified_lookup_c7() {
        let expr = qualified_col("c", "name").eq(lit("Alice"));
        assert_eq!(
            classifier_with_aliases().classify(&expr),
            PredicateClass::LookupOnly
        );
    }

    #[test]
    fn test_classify_qualified_stream_c7() {
        let expr = qualified_col("o", "amount").gt(lit(50));
        assert_eq!(
            classifier_with_aliases().classify(&expr),
            PredicateClass::StreamOnly
        );
    }

    #[test]
    fn test_classify_ambiguous_both_sides() {
        // An unqualified name present on both sides cannot be attributed to either.
        let shared = string_set(&["id"]);
        let c = PredicateClassifier::new(shared.clone(), shared, None, None);
        let expr = col("id").eq(lit(1));
        assert_eq!(c.classify(&expr), PredicateClass::CrossReference);
    }

    #[test]
    fn test_classify_nested_function() {
        let upper_name = Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
            func: datafusion::functions::string::upper(),
            args: vec![col("name")],
        });
        let expr = upper_name.eq(lit("ALICE"));
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    #[test]
    fn test_classify_is_null() {
        let expr = col("name").is_null();
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    #[test]
    fn test_classify_between() {
        let expr = Expr::Between(datafusion::logical_expr::expr::Between {
            expr: Box::new(col("amount")),
            negated: false,
            low: Box::new(lit(10)),
            high: Box::new(lit(100)),
        });
        assert_eq!(classifier().classify(&expr), PredicateClass::StreamOnly);
    }

    #[test]
    fn test_classify_in_list() {
        let expr = col("region").in_list(vec![lit("US"), lit("EU")], false);
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    // ---- conjunction splitting ---------------------------------------------------------

    #[test]
    fn test_split_flat_conjunction() {
        let expr = col("a")
            .eq(lit(1))
            .and(col("b").eq(lit(2)))
            .and(col("c").eq(lit(3)));
        assert_eq!(split_conjunction(&expr).len(), 3);
    }

    #[test]
    fn test_split_nested_conjunction() {
        let left = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
        let right = col("c").eq(lit(3)).and(col("d").eq(lit(4)));
        assert_eq!(split_conjunction(&left.and(right)).len(), 4);
    }

    #[test]
    fn test_split_single_predicate() {
        assert_eq!(split_conjunction(&col("a").eq(lit(1))).len(), 1);
    }

    #[test]
    fn test_split_or_not_split() {
        let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    // ---- plan fixtures -----------------------------------------------------------------

    fn test_metadata() -> LookupTableMetadata {
        LookupTableMetadata {
            connector: "postgres-cdc".to_string(),
            strategy: "replicated".to_string(),
            pushdown_mode: "auto".to_string(),
            primary_key: vec!["id".to_string()],
        }
    }

    fn to_df_schema(fields: Vec<Field>) -> Arc<DFSchema> {
        Arc::new(DFSchema::try_from(Schema::new(fields)).unwrap())
    }

    fn stream_fields() -> Vec<Field> {
        vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]
    }

    fn lookup_fields() -> Vec<Field> {
        vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("region", DataType::Utf8, true),
        ]
    }

    fn test_stream_schema() -> Arc<DFSchema> {
        to_df_schema(stream_fields())
    }

    fn test_lookup_schema() -> Arc<DFSchema> {
        to_df_schema(lookup_fields())
    }

    /// Join output schema: stream fields followed by lookup fields.
    fn test_output_schema() -> Arc<DFSchema> {
        let mut fields = stream_fields();
        fields.extend(lookup_fields());
        to_df_schema(fields)
    }

    /// A `customers` lookup join over an empty stream input, with no predicates attached.
    fn make_lookup_node(join_type: LookupJoinType) -> LookupJoinNode {
        let input = LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
            produce_one_row: false,
            schema: test_stream_schema(),
        });
        let join_key = JoinKeyPair {
            stream_expr: col("customer_id"),
            lookup_column: "id".to_string(),
        };

        LookupJoinNode::new(
            input,
            "customers".to_string(),
            test_lookup_schema(),
            vec![join_key],
            join_type,
            vec![],
            lookup_cols(),
            test_output_schema(),
            test_metadata(),
        )
    }

    fn make_filter_over_node(node: LookupJoinNode, predicate: Expr) -> LogicalPlan {
        let ext = LogicalPlan::Extension(Extension {
            node: Arc::new(node),
        });
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(ext)).unwrap())
    }

    fn full_capabilities() -> SourceCapabilitiesRegistry {
        let mut reg = SourceCapabilitiesRegistry::default();
        reg.register(
            "customers".to_string(),
            PlanSourceCapabilities {
                pushdown_mode: PlanPushdownMode::Full,
                eq_columns: lookup_cols(),
                range_columns: HashSet::new(),
                in_columns: HashSet::new(),
                supports_null_check: true,
            },
        );
        reg
    }

    fn no_capabilities() -> SourceCapabilitiesRegistry {
        SourceCapabilitiesRegistry::default()
    }

    /// Runs the rule once over `plan`.
    fn run_rule(plan: LogicalPlan, caps: SourceCapabilitiesRegistry) -> Transformed<LogicalPlan> {
        PredicateSplitterRule::new(caps)
            .rewrite(
                plan,
                &datafusion_optimizer::optimizer::OptimizerContext::new(),
            )
            .unwrap()
    }

    /// Filters a fresh `join_type` node by `predicate` and runs the rule.
    ///
    /// Asserts that the plan was rewritten into a lookup-join extension, then returns that
    /// node's `(pushdown, local)` predicate counts.
    fn split_counts(
        join_type: LookupJoinType,
        predicate: Expr,
        caps: SourceCapabilitiesRegistry,
    ) -> (usize, usize) {
        let plan = make_filter_over_node(make_lookup_node(join_type), predicate);
        let result = run_rule(plan, caps);
        assert!(result.transformed);

        let LogicalPlan::Extension(ext) = &result.data else {
            panic!("Expected Extension node");
        };
        let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
        (
            rebuilt.pushdown_predicates().len(),
            rebuilt.local_predicates().len(),
        )
    }

    // ---- rule behaviour ----------------------------------------------------------------

    #[test]
    fn test_pushdown_inner_join_lookup_only() {
        let counts = split_counts(
            LookupJoinType::Inner,
            col("region").eq(lit("US")),
            full_capabilities(),
        );
        assert_eq!(counts, (1, 0));
    }

    #[test]
    fn test_stream_predicate_stays_local() {
        let counts = split_counts(
            LookupJoinType::Inner,
            col("amount").gt(lit(100)),
            full_capabilities(),
        );
        assert_eq!(counts, (0, 1));
    }

    #[test]
    fn test_cross_ref_stays_local() {
        let counts = split_counts(
            LookupJoinType::Inner,
            col("amount").gt(col("id")),
            full_capabilities(),
        );
        assert_eq!(counts, (0, 1));
    }

    #[test]
    fn test_pushdown_disabled_keeps_local() {
        // No registry entry for `customers`, so nothing may be pushed.
        let counts = split_counts(
            LookupJoinType::Inner,
            col("region").eq(lit("US")),
            no_capabilities(),
        );
        assert_eq!(counts, (0, 1));
    }

    #[test]
    fn test_left_join_h10_safety() {
        let counts = split_counts(
            LookupJoinType::LeftOuter,
            col("region").eq(lit("US")),
            full_capabilities(),
        );
        assert_eq!(counts, (0, 1));
    }

    #[test]
    fn test_no_filter_no_predicates_passthrough() {
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(make_lookup_node(LookupJoinType::Inner)),
        });
        assert!(!run_rule(plan, full_capabilities()).transformed);
    }

    #[test]
    fn test_mixed_conjunction_split() {
        let predicate = col("region").eq(lit("US")).and(col("amount").gt(lit(100)));
        let counts = split_counts(LookupJoinType::Inner, predicate, full_capabilities());
        assert_eq!(counts, (1, 1));
    }

    #[test]
    fn test_not_eq_stays_local() {
        let counts = split_counts(
            LookupJoinType::Inner,
            col("region").not_eq(lit("US")),
            full_capabilities(),
        );
        assert_eq!(counts, (0, 1));
    }

    // ---- capabilities ------------------------------------------------------------------

    #[test]
    fn test_source_capabilities_registry() {
        let mut reg = SourceCapabilitiesRegistry::default();
        assert!(reg.get("foo").is_none());

        reg.register(
            "foo".to_string(),
            PlanSourceCapabilities {
                pushdown_mode: PlanPushdownMode::Full,
                ..Default::default()
            },
        );
        assert_eq!(
            reg.get("foo").unwrap().pushdown_mode,
            PlanPushdownMode::Full
        );
    }

    #[test]
    fn test_plan_source_capabilities_default() {
        let caps = PlanSourceCapabilities::default();
        assert_eq!(caps.pushdown_mode, PlanPushdownMode::None);
        assert!(caps.eq_columns.is_empty());
        assert!(!caps.supports_null_check);
    }
}