1#[allow(clippy::disallowed_types)] use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23
24use datafusion::logical_expr::logical_plan::LogicalPlan;
25use datafusion::logical_expr::{
26 BinaryExpr, Expr, Extension, Filter, Operator as DfOperator, UserDefinedLogicalNodeCore,
27};
28use datafusion_common::tree_node::Transformed;
29use datafusion_common::Result;
30use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
31
32use crate::datafusion::lookup_join::{LookupJoinNode, LookupJoinType};
33
/// Which side(s) of a lookup join a predicate references.
///
/// Produced by [`PredicateClassifier::classify`]; only [`PredicateClass::LookupOnly`]
/// predicates are ever candidates for pushdown to the lookup source.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredicateClass {
    /// Every referenced column belongs to the lookup table.
    LookupOnly,
    /// Every referenced column belongs to the stream input.
    StreamOnly,
    /// References columns of both sides, or contains an expression kind that
    /// cannot be attributed to a single side.
    CrossReference,
    /// References no columns at all (literals, placeholders, ...).
    Constant,
}
50
/// Attributes the columns referenced by an expression to the lookup side, the
/// stream side, or both.
///
/// Columns are matched by bare name and, when an alias was supplied for a side
/// at construction, also in `alias.column` form.
#[derive(Debug)]
pub struct PredicateClassifier {
    /// Bare column names of the lookup table.
    lookup_columns: HashSet<String>,
    /// Bare column names of the stream input.
    stream_columns: HashSet<String>,
    /// `"<lookup alias>.<column>"` for every lookup column (empty without an alias).
    lookup_qualified: HashSet<String>,
    /// `"<stream alias>.<column>"` for every stream column (empty without an alias).
    stream_qualified: HashSet<String>,
}
67
68impl PredicateClassifier {
69 #[must_use]
75 pub fn new(
76 lookup_columns: HashSet<String>,
77 stream_columns: HashSet<String>,
78 lookup_alias: Option<&str>,
79 stream_alias: Option<&str>,
80 ) -> Self {
81 let mut lookup_qualified = HashSet::new();
82 let mut stream_qualified = HashSet::new();
83
84 if let Some(alias) = lookup_alias {
85 for col in &lookup_columns {
86 lookup_qualified.insert(format!("{alias}.{col}"));
87 }
88 }
89 if let Some(alias) = stream_alias {
90 for col in &stream_columns {
91 stream_qualified.insert(format!("{alias}.{col}"));
92 }
93 }
94
95 Self {
96 lookup_columns,
97 stream_columns,
98 lookup_qualified,
99 stream_qualified,
100 }
101 }
102
103 #[must_use]
105 pub fn classify(&self, expr: &Expr) -> PredicateClass {
106 let mut has_lookup = false;
107 let mut has_stream = false;
108 self.walk_columns(expr, &mut has_lookup, &mut has_stream);
109
110 match (has_lookup, has_stream) {
111 (true, false) => PredicateClass::LookupOnly,
112 (false, true) => PredicateClass::StreamOnly,
113 (true, true) => PredicateClass::CrossReference,
114 (false, false) => PredicateClass::Constant,
115 }
116 }
117
118 fn walk_columns(&self, expr: &Expr, has_lookup: &mut bool, has_stream: &mut bool) {
120 match expr {
121 Expr::Column(col) => {
122 if let Some(relation) = &col.relation {
124 let qualified = format!("{}.{}", relation, col.name);
125 if self.lookup_qualified.contains(&qualified) {
126 *has_lookup = true;
127 return;
128 }
129 if self.stream_qualified.contains(&qualified) {
130 *has_stream = true;
131 return;
132 }
133 }
134 if self.lookup_columns.contains(&col.name) {
136 *has_lookup = true;
137 }
138 if self.stream_columns.contains(&col.name) {
139 *has_stream = true;
140 }
141 }
142 Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
143 self.walk_columns(left, has_lookup, has_stream);
144 self.walk_columns(right, has_lookup, has_stream);
145 }
146 Expr::Not(inner)
147 | Expr::IsNull(inner)
148 | Expr::IsNotNull(inner)
149 | Expr::Negative(inner)
150 | Expr::Cast(datafusion::logical_expr::Cast { expr: inner, .. })
151 | Expr::TryCast(datafusion::logical_expr::TryCast { expr: inner, .. }) => {
152 self.walk_columns(inner, has_lookup, has_stream);
153 }
154 Expr::Between(between) => {
155 self.walk_columns(&between.expr, has_lookup, has_stream);
156 self.walk_columns(&between.low, has_lookup, has_stream);
157 self.walk_columns(&between.high, has_lookup, has_stream);
158 }
159 Expr::InList(in_list) => {
160 self.walk_columns(&in_list.expr, has_lookup, has_stream);
161 for item in &in_list.list {
162 self.walk_columns(item, has_lookup, has_stream);
163 }
164 }
165 Expr::ScalarFunction(func) => {
166 for arg in &func.args {
167 self.walk_columns(arg, has_lookup, has_stream);
168 }
169 }
170 Expr::Like(like) => {
171 self.walk_columns(&like.expr, has_lookup, has_stream);
172 self.walk_columns(&like.pattern, has_lookup, has_stream);
173 }
174 Expr::Case(case) => {
175 if let Some(operand) = &case.expr {
176 self.walk_columns(operand, has_lookup, has_stream);
177 }
178 for (when, then) in &case.when_then_expr {
179 self.walk_columns(when, has_lookup, has_stream);
180 self.walk_columns(then, has_lookup, has_stream);
181 }
182 if let Some(else_expr) = &case.else_expr {
183 self.walk_columns(else_expr, has_lookup, has_stream);
184 }
185 }
186 Expr::Literal(..) | Expr::Placeholder(_) => {}
188 _ => {
190 *has_lookup = true;
191 *has_stream = true;
192 }
193 }
194 }
195}
196
/// Planner-side view of how much predicate pushdown a lookup source accepts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlanPushdownMode {
    /// Pushdown enabled.
    Full,
    /// Presumably restricted to key-column predicates. NOTE(review): the
    /// predicate splitter in this module currently treats this exactly like `Full`.
    KeyOnly,
    /// Pushdown disabled; all predicates are evaluated locally.
    None,
}
211
212#[derive(Debug, Clone)]
214pub struct PlanSourceCapabilities {
215 pub pushdown_mode: PlanPushdownMode,
217 pub eq_columns: HashSet<String>,
219 pub range_columns: HashSet<String>,
221 pub in_columns: HashSet<String>,
223 pub supports_null_check: bool,
225}
226
227impl Default for PlanSourceCapabilities {
228 fn default() -> Self {
229 Self {
230 pushdown_mode: PlanPushdownMode::None,
231 eq_columns: HashSet::new(),
232 range_columns: HashSet::new(),
233 in_columns: HashSet::new(),
234 supports_null_check: false,
235 }
236 }
237}
238
239#[derive(Debug, Default)]
241pub struct SourceCapabilitiesRegistry {
242 capabilities: HashMap<String, PlanSourceCapabilities>,
243}
244
245impl SourceCapabilitiesRegistry {
246 pub fn register(&mut self, table_name: String, caps: PlanSourceCapabilities) {
248 self.capabilities.insert(table_name, caps);
249 }
250
251 #[must_use]
253 pub fn get(&self, table_name: &str) -> Option<&PlanSourceCapabilities> {
254 self.capabilities.get(table_name)
255 }
256}
257
258#[must_use]
267pub fn split_conjunction(expr: &Expr) -> Vec<Expr> {
268 match expr {
269 Expr::BinaryExpr(BinaryExpr {
270 left,
271 op: DfOperator::And,
272 right,
273 }) => {
274 let mut parts = split_conjunction(left);
275 parts.extend(split_conjunction(right));
276 parts
277 }
278 other => vec![other.clone()],
279 }
280}
281
282#[derive(Debug)]
296pub struct PredicateSplitterRule {
297 capabilities: SourceCapabilitiesRegistry,
299}
300
301impl PredicateSplitterRule {
302 #[must_use]
304 pub fn new(capabilities: SourceCapabilitiesRegistry) -> Self {
305 Self { capabilities }
306 }
307
308 fn split_for_node(
312 &self,
313 node: &LookupJoinNode,
314 filter_predicates: &[Expr],
315 ) -> (Vec<Expr>, Vec<Expr>) {
316 let lookup_columns: HashSet<String> = node
318 .lookup_schema()
319 .fields()
320 .iter()
321 .map(|f| f.name().clone())
322 .collect();
323
324 let input_schema = node.inputs()[0].schema();
325 let stream_columns: HashSet<String> = input_schema
326 .fields()
327 .iter()
328 .map(|f| f.name().clone())
329 .collect();
330
331 let classifier = PredicateClassifier::new(
332 lookup_columns,
333 stream_columns,
334 node.lookup_alias(),
335 node.stream_alias(),
336 );
337
338 let caps = self.capabilities.get(node.lookup_table_name());
339 let pushdown_disabled = caps.is_none_or(|c| c.pushdown_mode == PlanPushdownMode::None);
340
341 let is_left_outer = node.join_type() == LookupJoinType::LeftOuter;
342
343 let mut pushdown = Vec::new();
344 let mut local = Vec::new();
345
346 let all_predicates = node
348 .pushdown_predicates()
349 .iter()
350 .chain(node.local_predicates().iter())
351 .chain(filter_predicates.iter())
352 .cloned();
353
354 for pred in all_predicates {
355 let class = classifier.classify(&pred);
356
357 let has_not_eq = contains_not_eq(&pred);
359
360 match class {
361 PredicateClass::LookupOnly => {
362 if is_left_outer || pushdown_disabled || has_not_eq {
364 local.push(pred);
365 } else {
366 pushdown.push(pred);
367 }
368 }
369 PredicateClass::StreamOnly
370 | PredicateClass::CrossReference
371 | PredicateClass::Constant => {
372 local.push(pred);
373 }
374 }
375 }
376
377 (pushdown, local)
378 }
379}
380
381fn contains_not_eq(expr: &Expr) -> bool {
383 match expr {
384 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
385 *op == DfOperator::NotEq || contains_not_eq(left) || contains_not_eq(right)
386 }
387 Expr::Not(inner) => contains_not_eq(inner),
388 _ => false,
389 }
390}
391
392impl OptimizerRule for PredicateSplitterRule {
393 fn name(&self) -> &'static str {
394 "predicate_splitter"
395 }
396
397 fn apply_order(&self) -> Option<ApplyOrder> {
398 Some(ApplyOrder::TopDown)
399 }
400
401 fn rewrite(
402 &self,
403 plan: LogicalPlan,
404 _config: &dyn OptimizerConfig,
405 ) -> Result<Transformed<LogicalPlan>> {
406 if let LogicalPlan::Filter(Filter {
408 predicate, input, ..
409 }) = &plan
410 {
411 if let LogicalPlan::Extension(ext) = input.as_ref() {
412 if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
413 let filter_preds = split_conjunction(predicate);
414 let (pushdown, local) = self.split_for_node(node, &filter_preds);
415
416 let inputs = node.inputs();
417 let rebuilt = LookupJoinNode::new(
418 inputs[0].clone(),
419 node.lookup_table_name().to_string(),
420 node.lookup_schema().clone(),
421 node.join_keys().to_vec(),
422 node.join_type(),
423 pushdown,
424 node.required_lookup_columns().clone(),
425 UserDefinedLogicalNodeCore::schema(node).clone(),
426 node.metadata().clone(),
427 )
428 .with_local_predicates(local)
429 .with_aliases(
430 node.lookup_alias().map(String::from),
431 node.stream_alias().map(String::from),
432 );
433
434 return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
435 node: Arc::new(rebuilt),
436 })));
437 }
438 }
439 }
440
441 if let LogicalPlan::Extension(ext) = &plan {
443 if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
444 if !node.pushdown_predicates().is_empty() || !node.local_predicates().is_empty() {
446 let (pushdown, local) = self.split_for_node(node, &[]);
447 let inputs = node.inputs();
448 let rebuilt = LookupJoinNode::new(
449 inputs[0].clone(),
450 node.lookup_table_name().to_string(),
451 node.lookup_schema().clone(),
452 node.join_keys().to_vec(),
453 node.join_type(),
454 pushdown,
455 node.required_lookup_columns().clone(),
456 UserDefinedLogicalNodeCore::schema(node).clone(),
457 node.metadata().clone(),
458 )
459 .with_local_predicates(local)
460 .with_aliases(
461 node.lookup_alias().map(String::from),
462 node.stream_alias().map(String::from),
463 );
464
465 return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
466 node: Arc::new(rebuilt),
467 })));
468 }
469 }
470 }
471
472 Ok(Transformed::no(plan))
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    #[allow(clippy::disallowed_types)]
    use std::collections::HashSet;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::col;
    use datafusion::prelude::lit;

    use crate::datafusion::lookup_join::{
        JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
    };

    // ---------------------------------------------------------------
    // Classifier fixtures
    // ---------------------------------------------------------------

    /// Builds an owned string set from literals.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    fn lookup_cols() -> HashSet<String> {
        names(&["id", "name", "region"])
    }

    fn stream_cols() -> HashSet<String> {
        names(&["order_id", "customer_id", "amount"])
    }

    fn classifier() -> PredicateClassifier {
        PredicateClassifier::new(lookup_cols(), stream_cols(), None, None)
    }

    fn classifier_with_aliases() -> PredicateClassifier {
        PredicateClassifier::new(lookup_cols(), stream_cols(), Some("c"), Some("o"))
    }

    /// A column reference qualified with `relation`.
    fn qualified_col(relation: &str, name: &str) -> Expr {
        Expr::Column(datafusion::common::Column::new(Some(relation), name))
    }

    // ---------------------------------------------------------------
    // Classification
    // ---------------------------------------------------------------

    #[test]
    fn test_classify_lookup_only() {
        let expr = col("region").eq(lit("US"));
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    #[test]
    fn test_classify_stream_only() {
        let expr = col("amount").gt(lit(100));
        assert_eq!(classifier().classify(&expr), PredicateClass::StreamOnly);
    }

    #[test]
    fn test_classify_cross_reference() {
        let expr = col("amount").gt(col("id"));
        assert_eq!(classifier().classify(&expr), PredicateClass::CrossReference);
    }

    #[test]
    fn test_classify_constant() {
        let expr = lit(1).eq(lit(1));
        assert_eq!(classifier().classify(&expr), PredicateClass::Constant);
    }

    #[test]
    fn test_classify_qualified_lookup_c7() {
        let expr = qualified_col("c", "name").eq(lit("Alice"));
        assert_eq!(
            classifier_with_aliases().classify(&expr),
            PredicateClass::LookupOnly
        );
    }

    #[test]
    fn test_classify_qualified_stream_c7() {
        let expr = qualified_col("o", "amount").gt(lit(50));
        assert_eq!(
            classifier_with_aliases().classify(&expr),
            PredicateClass::StreamOnly
        );
    }

    #[test]
    fn test_classify_ambiguous_both_sides() {
        // `id` exists on both sides and is unqualified, so it is ambiguous.
        let c = PredicateClassifier::new(names(&["id"]), names(&["id"]), None, None);
        assert_eq!(
            c.classify(&col("id").eq(lit(1))),
            PredicateClass::CrossReference
        );
    }

    #[test]
    fn test_classify_nested_function() {
        let upper_name = Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
            func: datafusion::functions::string::upper(),
            args: vec![col("name")],
        });
        let expr = upper_name.eq(lit("ALICE"));
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    #[test]
    fn test_classify_is_null() {
        let expr = col("name").is_null();
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    #[test]
    fn test_classify_between() {
        let expr = Expr::Between(datafusion::logical_expr::expr::Between {
            expr: Box::new(col("amount")),
            negated: false,
            low: Box::new(lit(10)),
            high: Box::new(lit(100)),
        });
        assert_eq!(classifier().classify(&expr), PredicateClass::StreamOnly);
    }

    #[test]
    fn test_classify_in_list() {
        let expr = col("region").in_list(vec![lit("US"), lit("EU")], false);
        assert_eq!(classifier().classify(&expr), PredicateClass::LookupOnly);
    }

    // ---------------------------------------------------------------
    // Conjunction splitting
    // ---------------------------------------------------------------

    #[test]
    fn test_split_flat_conjunction() {
        let expr = col("a")
            .eq(lit(1))
            .and(col("b").eq(lit(2)))
            .and(col("c").eq(lit(3)));
        assert_eq!(split_conjunction(&expr).len(), 3);
    }

    #[test]
    fn test_split_nested_conjunction() {
        let left = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
        let right = col("c").eq(lit(3)).and(col("d").eq(lit(4)));
        assert_eq!(split_conjunction(&left.and(right)).len(), 4);
    }

    #[test]
    fn test_split_single_predicate() {
        assert_eq!(split_conjunction(&col("a").eq(lit(1))).len(), 1);
    }

    #[test]
    fn test_split_or_not_split() {
        let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    // ---------------------------------------------------------------
    // Rule fixtures
    // ---------------------------------------------------------------

    fn test_metadata() -> LookupTableMetadata {
        LookupTableMetadata {
            connector: "postgres-cdc".to_string(),
            strategy: "replicated".to_string(),
            pushdown_mode: "auto".to_string(),
            primary_key: vec!["id".to_string()],
        }
    }

    fn stream_fields() -> Vec<Field> {
        vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]
    }

    fn lookup_fields() -> Vec<Field> {
        vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("region", DataType::Utf8, true),
        ]
    }

    fn df_schema(fields: Vec<Field>) -> Arc<DFSchema> {
        Arc::new(DFSchema::try_from(Schema::new(fields)).unwrap())
    }

    /// Stream `orders` joined to lookup `customers` on `customer_id = id`.
    fn make_lookup_node(join_type: LookupJoinType) -> LookupJoinNode {
        let input = LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
            produce_one_row: false,
            schema: df_schema(stream_fields()),
        });
        let output_fields = stream_fields().into_iter().chain(lookup_fields()).collect();

        LookupJoinNode::new(
            input,
            "customers".to_string(),
            df_schema(lookup_fields()),
            vec![JoinKeyPair {
                stream_expr: col("customer_id"),
                lookup_column: "id".to_string(),
            }],
            join_type,
            vec![],
            lookup_cols(),
            df_schema(output_fields),
            test_metadata(),
        )
    }

    fn full_capabilities() -> SourceCapabilitiesRegistry {
        let mut reg = SourceCapabilitiesRegistry::default();
        reg.register(
            "customers".to_string(),
            PlanSourceCapabilities {
                pushdown_mode: PlanPushdownMode::Full,
                eq_columns: lookup_cols(),
                range_columns: HashSet::new(),
                in_columns: HashSet::new(),
                supports_null_check: true,
            },
        );
        reg
    }

    /// Runs the rule once over `plan`.
    fn rewrite_with(caps: SourceCapabilitiesRegistry, plan: LogicalPlan) -> Transformed<LogicalPlan> {
        PredicateSplitterRule::new(caps)
            .rewrite(
                plan,
                &datafusion_optimizer::optimizer::OptimizerContext::new(),
            )
            .unwrap()
    }

    /// Runs the rule over `Filter(predicate)` on top of a fresh lookup join.
    fn split_filter(
        join_type: LookupJoinType,
        caps: SourceCapabilitiesRegistry,
        predicate: Expr,
    ) -> Transformed<LogicalPlan> {
        let join = LogicalPlan::Extension(Extension {
            node: Arc::new(make_lookup_node(join_type)),
        });
        let plan = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(join)).unwrap());
        rewrite_with(caps, plan)
    }

    /// Asserts the rewrite produced a lookup join with the given list sizes.
    fn assert_split(result: &Transformed<LogicalPlan>, pushdown: usize, local: usize) {
        assert!(result.transformed);
        let LogicalPlan::Extension(ext) = &result.data else {
            panic!("Expected Extension node");
        };
        let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
        assert_eq!(rebuilt.pushdown_predicates().len(), pushdown);
        assert_eq!(rebuilt.local_predicates().len(), local);
    }

    // ---------------------------------------------------------------
    // Rule behaviour
    // ---------------------------------------------------------------

    #[test]
    fn test_pushdown_inner_join_lookup_only() {
        let result = split_filter(
            LookupJoinType::Inner,
            full_capabilities(),
            col("region").eq(lit("US")),
        );
        assert_split(&result, 1, 0);
    }

    #[test]
    fn test_stream_predicate_stays_local() {
        let result = split_filter(
            LookupJoinType::Inner,
            full_capabilities(),
            col("amount").gt(lit(100)),
        );
        assert_split(&result, 0, 1);
    }

    #[test]
    fn test_cross_ref_stays_local() {
        let result = split_filter(
            LookupJoinType::Inner,
            full_capabilities(),
            col("amount").gt(col("id")),
        );
        assert_split(&result, 0, 1);
    }

    #[test]
    fn test_pushdown_disabled_keeps_local() {
        let result = split_filter(
            LookupJoinType::Inner,
            SourceCapabilitiesRegistry::default(),
            col("region").eq(lit("US")),
        );
        assert_split(&result, 0, 1);
    }

    #[test]
    fn test_left_join_h10_safety() {
        // Lookup-side filters must not be pushed below a left outer join.
        let result = split_filter(
            LookupJoinType::LeftOuter,
            full_capabilities(),
            col("region").eq(lit("US")),
        );
        assert_split(&result, 0, 1);
    }

    #[test]
    fn test_no_filter_no_predicates_passthrough() {
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(make_lookup_node(LookupJoinType::Inner)),
        });
        assert!(!rewrite_with(full_capabilities(), plan).transformed);
    }

    #[test]
    fn test_mixed_conjunction_split() {
        let predicate = col("region").eq(lit("US")).and(col("amount").gt(lit(100)));
        let result = split_filter(LookupJoinType::Inner, full_capabilities(), predicate);
        assert_split(&result, 1, 1);
    }

    #[test]
    fn test_not_eq_stays_local() {
        let result = split_filter(
            LookupJoinType::Inner,
            full_capabilities(),
            col("region").not_eq(lit("US")),
        );
        assert_split(&result, 0, 1);
    }

    // ---------------------------------------------------------------
    // Capabilities
    // ---------------------------------------------------------------

    #[test]
    fn test_source_capabilities_registry() {
        let mut reg = SourceCapabilitiesRegistry::default();
        assert!(reg.get("foo").is_none());

        reg.register(
            "foo".to_string(),
            PlanSourceCapabilities {
                pushdown_mode: PlanPushdownMode::Full,
                ..Default::default()
            },
        );
        assert_eq!(
            reg.get("foo").unwrap().pushdown_mode,
            PlanPushdownMode::Full
        );
    }

    #[test]
    fn test_plan_source_capabilities_default() {
        let caps = PlanSourceCapabilities::default();
        assert_eq!(caps.pushdown_mode, PlanPushdownMode::None);
        assert!(caps.eq_columns.is_empty());
        assert!(!caps.supports_null_check);
    }
}