Skip to main content

laminar_sql/planner/
predicate_split.rs

1//! Predicate splitting and pushdown for lookup joins.
2//!
3//! This module classifies WHERE/ON predicates in lookup join queries and
4//! splits them into pushdown vs local evaluation categories. It implements
5//! a DataFusion optimizer rule (`PredicateSplitterRule`) that absorbs
6//! filter nodes above `LookupJoinNode` and assigns each predicate to the
7//! correct execution site.
8//!
9//! ## Key Safety Rules
10//!
11//! - **H10 (LEFT JOIN safety):** WHERE-clause predicates on lookup-only
12//!   columns above a `LeftOuter` join must NOT be pushed down — doing so
13//!   changes the semantics by filtering out NULL-extended rows.
14//! - **C7 (qualified columns):** When aliases are present, `col.relation`
15//!   is checked first for unambiguous resolution before falling back to
16//!   unqualified column name matching.
17//! - **`NotEq`** predicates are classified normally but are never pushed
18//!   down (they cannot use equality indexes on the source).
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use datafusion::logical_expr::logical_plan::LogicalPlan;
24use datafusion::logical_expr::{
25    BinaryExpr, Expr, Extension, Filter, Operator as DfOperator, UserDefinedLogicalNodeCore,
26};
27use datafusion_common::tree_node::Transformed;
28use datafusion_common::Result;
29use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
30
31use crate::datafusion::lookup_join::{LookupJoinNode, LookupJoinType};
32
33// ---------------------------------------------------------------------------
34// Predicate Classification
35// ---------------------------------------------------------------------------
36
/// Which side(s) of a lookup join a predicate touches.
///
/// Produced by [`PredicateClassifier::classify`]; only `LookupOnly`
/// predicates are ever candidates for pushdown to the lookup source.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredicateClass {
    /// Every referenced column belongs to the lookup table (pushdown candidate).
    LookupOnly,
    /// Every referenced column belongs to the stream input (evaluated locally).
    StreamOnly,
    /// Columns from both sides are referenced (evaluated locally).
    CrossReference,
    /// No column is referenced at all, i.e. a constant expression
    /// (evaluated locally).
    Constant,
}
49
/// Decides which join side a predicate belongs to by looking at the columns
/// it references.
///
/// Two name spaces are kept per side (audit C7): bare column names and
/// alias-qualified names of the form `"alias.col"`. A column that carries a
/// relation qualifier is resolved against the qualified sets first.
#[derive(Debug)]
pub struct PredicateClassifier {
    /// Bare (unqualified) column names of the lookup table.
    lookup_columns: HashSet<String>,
    /// Bare (unqualified) column names of the stream input.
    stream_columns: HashSet<String>,
    /// Lookup columns in qualified form: `"alias.col"` or `"table.col"`.
    lookup_qualified: HashSet<String>,
    /// Stream columns in qualified form: `"alias.col"` or `"table.col"`.
    stream_qualified: HashSet<String>,
}
66
67impl PredicateClassifier {
68    /// Creates a new classifier from column sets.
69    ///
70    /// `lookup_alias` / `stream_alias` are the SQL aliases (e.g., `c` for
71    /// `customers c`). When provided, qualified lookups like `c.name` can
72    /// be resolved unambiguously.
73    #[must_use]
74    pub fn new(
75        lookup_columns: HashSet<String>,
76        stream_columns: HashSet<String>,
77        lookup_alias: Option<&str>,
78        stream_alias: Option<&str>,
79    ) -> Self {
80        let mut lookup_qualified = HashSet::new();
81        let mut stream_qualified = HashSet::new();
82
83        if let Some(alias) = lookup_alias {
84            for col in &lookup_columns {
85                lookup_qualified.insert(format!("{alias}.{col}"));
86            }
87        }
88        if let Some(alias) = stream_alias {
89            for col in &stream_columns {
90                stream_qualified.insert(format!("{alias}.{col}"));
91            }
92        }
93
94        Self {
95            lookup_columns,
96            stream_columns,
97            lookup_qualified,
98            stream_qualified,
99        }
100    }
101
102    /// Classify a predicate expression.
103    #[must_use]
104    pub fn classify(&self, expr: &Expr) -> PredicateClass {
105        let mut has_lookup = false;
106        let mut has_stream = false;
107        self.walk_columns(expr, &mut has_lookup, &mut has_stream);
108
109        match (has_lookup, has_stream) {
110            (true, false) => PredicateClass::LookupOnly,
111            (false, true) => PredicateClass::StreamOnly,
112            (true, true) => PredicateClass::CrossReference,
113            (false, false) => PredicateClass::Constant,
114        }
115    }
116
117    /// Recursively walk an expression to find column references.
118    fn walk_columns(&self, expr: &Expr, has_lookup: &mut bool, has_stream: &mut bool) {
119        match expr {
120            Expr::Column(col) => {
121                // C7: check qualified form first
122                if let Some(relation) = &col.relation {
123                    let qualified = format!("{}.{}", relation, col.name);
124                    if self.lookup_qualified.contains(&qualified) {
125                        *has_lookup = true;
126                        return;
127                    }
128                    if self.stream_qualified.contains(&qualified) {
129                        *has_stream = true;
130                        return;
131                    }
132                }
133                // Fall back to unqualified
134                if self.lookup_columns.contains(&col.name) {
135                    *has_lookup = true;
136                }
137                if self.stream_columns.contains(&col.name) {
138                    *has_stream = true;
139                }
140            }
141            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
142                self.walk_columns(left, has_lookup, has_stream);
143                self.walk_columns(right, has_lookup, has_stream);
144            }
145            Expr::Not(inner)
146            | Expr::IsNull(inner)
147            | Expr::IsNotNull(inner)
148            | Expr::Negative(inner)
149            | Expr::Cast(datafusion::logical_expr::Cast { expr: inner, .. })
150            | Expr::TryCast(datafusion::logical_expr::TryCast { expr: inner, .. }) => {
151                self.walk_columns(inner, has_lookup, has_stream);
152            }
153            Expr::Between(between) => {
154                self.walk_columns(&between.expr, has_lookup, has_stream);
155                self.walk_columns(&between.low, has_lookup, has_stream);
156                self.walk_columns(&between.high, has_lookup, has_stream);
157            }
158            Expr::InList(in_list) => {
159                self.walk_columns(&in_list.expr, has_lookup, has_stream);
160                for item in &in_list.list {
161                    self.walk_columns(item, has_lookup, has_stream);
162                }
163            }
164            Expr::ScalarFunction(func) => {
165                for arg in &func.args {
166                    self.walk_columns(arg, has_lookup, has_stream);
167                }
168            }
169            Expr::Like(like) => {
170                self.walk_columns(&like.expr, has_lookup, has_stream);
171                self.walk_columns(&like.pattern, has_lookup, has_stream);
172            }
173            Expr::Case(case) => {
174                if let Some(operand) = &case.expr {
175                    self.walk_columns(operand, has_lookup, has_stream);
176                }
177                for (when, then) in &case.when_then_expr {
178                    self.walk_columns(when, has_lookup, has_stream);
179                    self.walk_columns(then, has_lookup, has_stream);
180                }
181                if let Some(else_expr) = &case.else_expr {
182                    self.walk_columns(else_expr, has_lookup, has_stream);
183                }
184            }
185            // Literals, placeholders — no columns
186            Expr::Literal(..) | Expr::Placeholder(_) => {}
187            // Catch-all: conservative — mark both sides
188            _ => {
189                *has_lookup = true;
190                *has_stream = true;
191            }
192        }
193    }
194}
195
196// ---------------------------------------------------------------------------
197// Source Capabilities
198// ---------------------------------------------------------------------------
199
/// How much predicate pushdown a lookup source accepts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlanPushdownMode {
    /// Everything the source understands may be pushed down
    /// (eq, range, in, null checks).
    Full,
    /// Pushdown is limited to key equality predicates.
    KeyOnly,
    /// Pushdown is disabled entirely.
    None,
}
210
211/// Describes a source's pushdown capabilities for the optimizer.
212#[derive(Debug, Clone)]
213pub struct PlanSourceCapabilities {
214    /// Overall pushdown mode.
215    pub pushdown_mode: PlanPushdownMode,
216    /// Columns that support equality pushdown.
217    pub eq_columns: HashSet<String>,
218    /// Columns that support range pushdown.
219    pub range_columns: HashSet<String>,
220    /// Columns that support IN-list pushdown.
221    pub in_columns: HashSet<String>,
222    /// Whether the source supports IS NULL / IS NOT NULL checks.
223    pub supports_null_check: bool,
224}
225
226impl Default for PlanSourceCapabilities {
227    fn default() -> Self {
228        Self {
229            pushdown_mode: PlanPushdownMode::None,
230            eq_columns: HashSet::new(),
231            range_columns: HashSet::new(),
232            in_columns: HashSet::new(),
233            supports_null_check: false,
234        }
235    }
236}
237
238/// Registry mapping lookup table names to their source capabilities.
239#[derive(Debug, Default)]
240pub struct SourceCapabilitiesRegistry {
241    capabilities: HashMap<String, PlanSourceCapabilities>,
242}
243
244impl SourceCapabilitiesRegistry {
245    /// Register capabilities for a lookup table.
246    pub fn register(&mut self, table_name: String, caps: PlanSourceCapabilities) {
247        self.capabilities.insert(table_name, caps);
248    }
249
250    /// Get capabilities for a lookup table.
251    #[must_use]
252    pub fn get(&self, table_name: &str) -> Option<&PlanSourceCapabilities> {
253        self.capabilities.get(table_name)
254    }
255}
256
257// ---------------------------------------------------------------------------
258// Conjunction Splitting
259// ---------------------------------------------------------------------------
260
261/// Splits a conjunction (AND chain) into individual predicates.
262///
263/// `A AND B AND C` → `[A, B, C]`.
264/// OR expressions and non-AND binary expressions are kept as single items.
265#[must_use]
266pub fn split_conjunction(expr: &Expr) -> Vec<Expr> {
267    match expr {
268        Expr::BinaryExpr(BinaryExpr {
269            left,
270            op: DfOperator::And,
271            right,
272        }) => {
273            let mut parts = split_conjunction(left);
274            parts.extend(split_conjunction(right));
275            parts
276        }
277        other => vec![other.clone()],
278    }
279}
280
281// ---------------------------------------------------------------------------
282// Optimizer Rule
283// ---------------------------------------------------------------------------
284
285/// DataFusion optimizer rule that splits predicates for lookup joins.
286///
287/// Runs `TopDown` to catch `Filter` nodes above `LookupJoinNode` first.
288///
289/// Two cases:
290/// 1. **Filter above LookupJoinNode** — absorb the filter, classify
291///    each conjunct, and assign to pushdown or local.
292/// 2. **Direct LookupJoinNode** — re-classify existing pushdown predicates
293///    (e.g., after a previous pass added them).
294#[derive(Debug)]
295pub struct PredicateSplitterRule {
296    /// Per-table source capabilities.
297    capabilities: SourceCapabilitiesRegistry,
298}
299
300impl PredicateSplitterRule {
301    /// Creates a new rule with the given capabilities registry.
302    #[must_use]
303    pub fn new(capabilities: SourceCapabilitiesRegistry) -> Self {
304        Self { capabilities }
305    }
306
307    /// Split predicates for a `LookupJoinNode`, given a list of predicates
308    /// that come from an absorbed `Filter` (if any) plus the node's
309    /// existing predicates.
310    fn split_for_node(
311        &self,
312        node: &LookupJoinNode,
313        filter_predicates: &[Expr],
314    ) -> (Vec<Expr>, Vec<Expr>) {
315        // Build column sets from schemas
316        let lookup_columns: HashSet<String> = node
317            .lookup_schema()
318            .fields()
319            .iter()
320            .map(|f| f.name().clone())
321            .collect();
322
323        let input_schema = node.inputs()[0].schema();
324        let stream_columns: HashSet<String> = input_schema
325            .fields()
326            .iter()
327            .map(|f| f.name().clone())
328            .collect();
329
330        let classifier = PredicateClassifier::new(
331            lookup_columns,
332            stream_columns,
333            node.lookup_alias(),
334            node.stream_alias(),
335        );
336
337        let caps = self.capabilities.get(node.lookup_table_name());
338        let pushdown_disabled = caps.is_none_or(|c| c.pushdown_mode == PlanPushdownMode::None);
339
340        let is_left_outer = node.join_type() == LookupJoinType::LeftOuter;
341
342        let mut pushdown = Vec::new();
343        let mut local = Vec::new();
344
345        // Include existing predicates from the node
346        let all_predicates = node
347            .pushdown_predicates()
348            .iter()
349            .chain(node.local_predicates().iter())
350            .chain(filter_predicates.iter())
351            .cloned();
352
353        for pred in all_predicates {
354            let class = classifier.classify(&pred);
355
356            // NotEq predicates never push down
357            let has_not_eq = contains_not_eq(&pred);
358
359            match class {
360                PredicateClass::LookupOnly => {
361                    // H10: LEFT OUTER WHERE-clause lookup-only preds stay local
362                    if is_left_outer || pushdown_disabled || has_not_eq {
363                        local.push(pred);
364                    } else {
365                        pushdown.push(pred);
366                    }
367                }
368                PredicateClass::StreamOnly
369                | PredicateClass::CrossReference
370                | PredicateClass::Constant => {
371                    local.push(pred);
372                }
373            }
374        }
375
376        (pushdown, local)
377    }
378}
379
380/// Check if an expression contains a `NotEq` operator.
381fn contains_not_eq(expr: &Expr) -> bool {
382    match expr {
383        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
384            *op == DfOperator::NotEq || contains_not_eq(left) || contains_not_eq(right)
385        }
386        Expr::Not(inner) => contains_not_eq(inner),
387        _ => false,
388    }
389}
390
391impl OptimizerRule for PredicateSplitterRule {
392    fn name(&self) -> &'static str {
393        "predicate_splitter"
394    }
395
396    fn apply_order(&self) -> Option<ApplyOrder> {
397        Some(ApplyOrder::TopDown)
398    }
399
400    fn rewrite(
401        &self,
402        plan: LogicalPlan,
403        _config: &dyn OptimizerConfig,
404    ) -> Result<Transformed<LogicalPlan>> {
405        // Case 1: Filter above a LookupJoinNode
406        if let LogicalPlan::Filter(Filter {
407            predicate, input, ..
408        }) = &plan
409        {
410            if let LogicalPlan::Extension(ext) = input.as_ref() {
411                if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
412                    let filter_preds = split_conjunction(predicate);
413                    let (pushdown, local) = self.split_for_node(node, &filter_preds);
414
415                    let inputs = node.inputs();
416                    let rebuilt = LookupJoinNode::new(
417                        inputs[0].clone(),
418                        node.lookup_table_name().to_string(),
419                        node.lookup_schema().clone(),
420                        node.join_keys().to_vec(),
421                        node.join_type(),
422                        pushdown,
423                        node.required_lookup_columns().clone(),
424                        UserDefinedLogicalNodeCore::schema(node).clone(),
425                        node.metadata().clone(),
426                    )
427                    .with_local_predicates(local)
428                    .with_aliases(
429                        node.lookup_alias().map(String::from),
430                        node.stream_alias().map(String::from),
431                    );
432
433                    return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
434                        node: Arc::new(rebuilt),
435                    })));
436                }
437            }
438        }
439
440        // Case 2: Direct LookupJoinNode (re-classify existing predicates)
441        if let LogicalPlan::Extension(ext) = &plan {
442            if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
443                // Only re-classify if there are predicates to work with
444                if !node.pushdown_predicates().is_empty() || !node.local_predicates().is_empty() {
445                    let (pushdown, local) = self.split_for_node(node, &[]);
446                    let inputs = node.inputs();
447                    let rebuilt = LookupJoinNode::new(
448                        inputs[0].clone(),
449                        node.lookup_table_name().to_string(),
450                        node.lookup_schema().clone(),
451                        node.join_keys().to_vec(),
452                        node.join_type(),
453                        pushdown,
454                        node.required_lookup_columns().clone(),
455                        UserDefinedLogicalNodeCore::schema(node).clone(),
456                        node.metadata().clone(),
457                    )
458                    .with_local_predicates(local)
459                    .with_aliases(
460                        node.lookup_alias().map(String::from),
461                        node.stream_alias().map(String::from),
462                    );
463
464                    return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
465                        node: Arc::new(rebuilt),
466                    })));
467                }
468            }
469        }
470
471        Ok(Transformed::no(plan))
472    }
473}
474
475#[cfg(test)]
476mod tests {
477    use super::*;
478
479    use std::collections::HashSet;
480
481    use arrow::datatypes::{DataType, Field, Schema};
482    use datafusion::common::DFSchema;
483    use datafusion::logical_expr::col;
484    use datafusion::prelude::lit;
485
486    use crate::datafusion::lookup_join::{
487        JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
488    };
489
490    fn lookup_cols() -> HashSet<String> {
491        HashSet::from(["id".to_string(), "name".to_string(), "region".to_string()])
492    }
493
494    fn stream_cols() -> HashSet<String> {
495        HashSet::from([
496            "order_id".to_string(),
497            "customer_id".to_string(),
498            "amount".to_string(),
499        ])
500    }
501
502    fn classifier() -> PredicateClassifier {
503        PredicateClassifier::new(lookup_cols(), stream_cols(), None, None)
504    }
505
506    fn classifier_with_aliases() -> PredicateClassifier {
507        PredicateClassifier::new(lookup_cols(), stream_cols(), Some("c"), Some("o"))
508    }
509
510    // -----------------------------------------------------------------------
511    // PredicateClassifier tests
512    // -----------------------------------------------------------------------
513
514    #[test]
515    fn test_classify_lookup_only() {
516        let c = classifier();
517        let expr = col("region").eq(lit("US"));
518        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
519    }
520
521    #[test]
522    fn test_classify_stream_only() {
523        let c = classifier();
524        let expr = col("amount").gt(lit(100));
525        assert_eq!(c.classify(&expr), PredicateClass::StreamOnly);
526    }
527
528    #[test]
529    fn test_classify_cross_reference() {
530        let c = classifier();
531        // amount (stream) > id (lookup) → cross-reference
532        let expr = col("amount").gt(col("id"));
533        assert_eq!(c.classify(&expr), PredicateClass::CrossReference);
534    }
535
536    #[test]
537    fn test_classify_constant() {
538        let c = classifier();
539        let expr = lit(1).eq(lit(1));
540        assert_eq!(c.classify(&expr), PredicateClass::Constant);
541    }
542
543    #[test]
544    fn test_classify_qualified_lookup_c7() {
545        let c = classifier_with_aliases();
546        // c.name should resolve to lookup via qualified match
547        let expr = Expr::Column(datafusion::common::Column::new(Some::<&str>("c"), "name"))
548            .eq(lit("Alice"));
549        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
550    }
551
552    #[test]
553    fn test_classify_qualified_stream_c7() {
554        let c = classifier_with_aliases();
555        let expr =
556            Expr::Column(datafusion::common::Column::new(Some::<&str>("o"), "amount")).gt(lit(50));
557        assert_eq!(c.classify(&expr), PredicateClass::StreamOnly);
558    }
559
560    #[test]
561    fn test_classify_ambiguous_both_sides() {
562        // Column name exists in both sides without qualifier → both flags set
563        let lookup = HashSet::from(["id".to_string()]);
564        let stream = HashSet::from(["id".to_string()]);
565        let c = PredicateClassifier::new(lookup, stream, None, None);
566        let expr = col("id").eq(lit(1));
567        assert_eq!(c.classify(&expr), PredicateClass::CrossReference);
568    }
569
570    #[test]
571    fn test_classify_nested_function() {
572        let c = classifier();
573        // UPPER(name) = 'ALICE' — name is lookup-only
574        let expr = Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
575            func: datafusion::functions::string::upper(),
576            args: vec![col("name")],
577        })
578        .eq(lit("ALICE"));
579        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
580    }
581
582    #[test]
583    fn test_classify_is_null() {
584        let c = classifier();
585        let expr = col("name").is_null();
586        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
587    }
588
589    #[test]
590    fn test_classify_between() {
591        let c = classifier();
592        let expr = Expr::Between(datafusion::logical_expr::expr::Between {
593            expr: Box::new(col("amount")),
594            negated: false,
595            low: Box::new(lit(10)),
596            high: Box::new(lit(100)),
597        });
598        assert_eq!(c.classify(&expr), PredicateClass::StreamOnly);
599    }
600
601    #[test]
602    fn test_classify_in_list() {
603        let c = classifier();
604        let expr = col("region").in_list(vec![lit("US"), lit("EU")], false);
605        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
606    }
607
608    // -----------------------------------------------------------------------
609    // split_conjunction tests
610    // -----------------------------------------------------------------------
611
612    #[test]
613    fn test_split_flat_conjunction() {
614        let expr = col("a")
615            .eq(lit(1))
616            .and(col("b").eq(lit(2)))
617            .and(col("c").eq(lit(3)));
618        let parts = split_conjunction(&expr);
619        assert_eq!(parts.len(), 3);
620    }
621
622    #[test]
623    fn test_split_nested_conjunction() {
624        // (A AND B) AND (C AND D)
625        let left = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
626        let right = col("c").eq(lit(3)).and(col("d").eq(lit(4)));
627        let expr = left.and(right);
628        let parts = split_conjunction(&expr);
629        assert_eq!(parts.len(), 4);
630    }
631
632    #[test]
633    fn test_split_single_predicate() {
634        let expr = col("a").eq(lit(1));
635        let parts = split_conjunction(&expr);
636        assert_eq!(parts.len(), 1);
637    }
638
639    #[test]
640    fn test_split_or_not_split() {
641        // OR should NOT be split
642        let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
643        let parts = split_conjunction(&expr);
644        assert_eq!(parts.len(), 1);
645    }
646
647    // -----------------------------------------------------------------------
648    // PredicateSplitterRule integration tests
649    // -----------------------------------------------------------------------
650
651    fn test_metadata() -> LookupTableMetadata {
652        LookupTableMetadata {
653            connector: "postgres-cdc".to_string(),
654            strategy: "replicated".to_string(),
655            pushdown_mode: "auto".to_string(),
656            primary_key: vec!["id".to_string()],
657        }
658    }
659
660    fn test_stream_schema() -> Arc<DFSchema> {
661        Arc::new(
662            DFSchema::try_from(Schema::new(vec![
663                Field::new("order_id", DataType::Int64, false),
664                Field::new("customer_id", DataType::Int64, false),
665                Field::new("amount", DataType::Float64, false),
666            ]))
667            .unwrap(),
668        )
669    }
670
671    fn test_lookup_schema() -> Arc<DFSchema> {
672        Arc::new(
673            DFSchema::try_from(Schema::new(vec![
674                Field::new("id", DataType::Int64, false),
675                Field::new("name", DataType::Utf8, true),
676                Field::new("region", DataType::Utf8, true),
677            ]))
678            .unwrap(),
679        )
680    }
681
682    fn test_output_schema() -> Arc<DFSchema> {
683        Arc::new(
684            DFSchema::try_from(Schema::new(vec![
685                Field::new("order_id", DataType::Int64, false),
686                Field::new("customer_id", DataType::Int64, false),
687                Field::new("amount", DataType::Float64, false),
688                Field::new("id", DataType::Int64, false),
689                Field::new("name", DataType::Utf8, true),
690                Field::new("region", DataType::Utf8, true),
691            ]))
692            .unwrap(),
693        )
694    }
695
696    fn make_lookup_node(join_type: LookupJoinType) -> LookupJoinNode {
697        let stream_schema = test_stream_schema();
698        let input = LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
699            produce_one_row: false,
700            schema: stream_schema,
701        });
702
703        LookupJoinNode::new(
704            input,
705            "customers".to_string(),
706            test_lookup_schema(),
707            vec![JoinKeyPair {
708                stream_expr: col("customer_id"),
709                lookup_column: "id".to_string(),
710            }],
711            join_type,
712            vec![],
713            HashSet::from(["id".to_string(), "name".to_string(), "region".to_string()]),
714            test_output_schema(),
715            test_metadata(),
716        )
717    }
718
719    fn make_filter_over_node(node: LookupJoinNode, predicate: Expr) -> LogicalPlan {
720        let ext = LogicalPlan::Extension(Extension {
721            node: Arc::new(node),
722        });
723        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(ext)).unwrap())
724    }
725
726    fn full_capabilities() -> SourceCapabilitiesRegistry {
727        let mut reg = SourceCapabilitiesRegistry::default();
728        reg.register(
729            "customers".to_string(),
730            PlanSourceCapabilities {
731                pushdown_mode: PlanPushdownMode::Full,
732                eq_columns: HashSet::from([
733                    "id".to_string(),
734                    "name".to_string(),
735                    "region".to_string(),
736                ]),
737                range_columns: HashSet::new(),
738                in_columns: HashSet::new(),
739                supports_null_check: true,
740            },
741        );
742        reg
743    }
744
745    fn no_capabilities() -> SourceCapabilitiesRegistry {
746        SourceCapabilitiesRegistry::default()
747    }
748
749    #[test]
750    fn test_pushdown_inner_join_lookup_only() {
751        let node = make_lookup_node(LookupJoinType::Inner);
752        let filter_pred = col("region").eq(lit("US"));
753        let plan = make_filter_over_node(node, filter_pred);
754
755        let rule = PredicateSplitterRule::new(full_capabilities());
756        let result = rule
757            .rewrite(
758                plan,
759                &datafusion_optimizer::optimizer::OptimizerContext::new(),
760            )
761            .unwrap();
762
763        assert!(result.transformed);
764        if let LogicalPlan::Extension(ext) = &result.data {
765            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
766            assert_eq!(rebuilt.pushdown_predicates().len(), 1);
767            assert_eq!(rebuilt.local_predicates().len(), 0);
768        } else {
769            panic!("Expected Extension node");
770        }
771    }
772
773    #[test]
774    fn test_stream_predicate_stays_local() {
775        let node = make_lookup_node(LookupJoinType::Inner);
776        let filter_pred = col("amount").gt(lit(100));
777        let plan = make_filter_over_node(node, filter_pred);
778
779        let rule = PredicateSplitterRule::new(full_capabilities());
780        let result = rule
781            .rewrite(
782                plan,
783                &datafusion_optimizer::optimizer::OptimizerContext::new(),
784            )
785            .unwrap();
786
787        assert!(result.transformed);
788        if let LogicalPlan::Extension(ext) = &result.data {
789            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
790            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
791            assert_eq!(rebuilt.local_predicates().len(), 1);
792        } else {
793            panic!("Expected Extension node");
794        }
795    }
796
797    #[test]
798    fn test_cross_ref_stays_local() {
799        let node = make_lookup_node(LookupJoinType::Inner);
800        // amount > id  (crosses stream and lookup)
801        let filter_pred = col("amount").gt(col("id"));
802        let plan = make_filter_over_node(node, filter_pred);
803
804        let rule = PredicateSplitterRule::new(full_capabilities());
805        let result = rule
806            .rewrite(
807                plan,
808                &datafusion_optimizer::optimizer::OptimizerContext::new(),
809            )
810            .unwrap();
811
812        assert!(result.transformed);
813        if let LogicalPlan::Extension(ext) = &result.data {
814            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
815            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
816            assert_eq!(rebuilt.local_predicates().len(), 1);
817        } else {
818            panic!("Expected Extension node");
819        }
820    }
821
822    #[test]
823    fn test_pushdown_disabled_keeps_local() {
824        let node = make_lookup_node(LookupJoinType::Inner);
825        let filter_pred = col("region").eq(lit("US"));
826        let plan = make_filter_over_node(node, filter_pred);
827
828        // No capabilities registered → pushdown disabled
829        let rule = PredicateSplitterRule::new(no_capabilities());
830        let result = rule
831            .rewrite(
832                plan,
833                &datafusion_optimizer::optimizer::OptimizerContext::new(),
834            )
835            .unwrap();
836
837        assert!(result.transformed);
838        if let LogicalPlan::Extension(ext) = &result.data {
839            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
840            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
841            assert_eq!(rebuilt.local_predicates().len(), 1);
842        } else {
843            panic!("Expected Extension node");
844        }
845    }
846
847    #[test]
848    fn test_left_join_h10_safety() {
849        // H10: LEFT OUTER lookup-only preds must NOT be pushed down
850        let node = make_lookup_node(LookupJoinType::LeftOuter);
851        let filter_pred = col("region").eq(lit("US"));
852        let plan = make_filter_over_node(node, filter_pred);
853
854        let rule = PredicateSplitterRule::new(full_capabilities());
855        let result = rule
856            .rewrite(
857                plan,
858                &datafusion_optimizer::optimizer::OptimizerContext::new(),
859            )
860            .unwrap();
861
862        assert!(result.transformed);
863        if let LogicalPlan::Extension(ext) = &result.data {
864            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
865            // Should stay local due to H10
866            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
867            assert_eq!(rebuilt.local_predicates().len(), 1);
868        } else {
869            panic!("Expected Extension node");
870        }
871    }
872
873    #[test]
874    fn test_no_filter_no_predicates_passthrough() {
875        let node = make_lookup_node(LookupJoinType::Inner);
876        let plan = LogicalPlan::Extension(Extension {
877            node: Arc::new(node),
878        });
879
880        let rule = PredicateSplitterRule::new(full_capabilities());
881        let result = rule
882            .rewrite(
883                plan,
884                &datafusion_optimizer::optimizer::OptimizerContext::new(),
885            )
886            .unwrap();
887
888        // No predicates to split → no transformation
889        assert!(!result.transformed);
890    }
891
892    #[test]
893    fn test_mixed_conjunction_split() {
894        let node = make_lookup_node(LookupJoinType::Inner);
895        // region = 'US' AND amount > 100
896        let filter_pred = col("region").eq(lit("US")).and(col("amount").gt(lit(100)));
897        let plan = make_filter_over_node(node, filter_pred);
898
899        let rule = PredicateSplitterRule::new(full_capabilities());
900        let result = rule
901            .rewrite(
902                plan,
903                &datafusion_optimizer::optimizer::OptimizerContext::new(),
904            )
905            .unwrap();
906
907        assert!(result.transformed);
908        if let LogicalPlan::Extension(ext) = &result.data {
909            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
910            // region = 'US' → pushdown, amount > 100 → local
911            assert_eq!(rebuilt.pushdown_predicates().len(), 1);
912            assert_eq!(rebuilt.local_predicates().len(), 1);
913        } else {
914            panic!("Expected Extension node");
915        }
916    }
917
918    #[test]
919    fn test_not_eq_stays_local() {
920        let node = make_lookup_node(LookupJoinType::Inner);
921        let filter_pred = col("region").not_eq(lit("US"));
922        let plan = make_filter_over_node(node, filter_pred);
923
924        let rule = PredicateSplitterRule::new(full_capabilities());
925        let result = rule
926            .rewrite(
927                plan,
928                &datafusion_optimizer::optimizer::OptimizerContext::new(),
929            )
930            .unwrap();
931
932        assert!(result.transformed);
933        if let LogicalPlan::Extension(ext) = &result.data {
934            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
935            // NotEq never pushed down
936            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
937            assert_eq!(rebuilt.local_predicates().len(), 1);
938        } else {
939            panic!("Expected Extension node");
940        }
941    }
942
943    #[test]
944    fn test_source_capabilities_registry() {
945        let mut reg = SourceCapabilitiesRegistry::default();
946        assert!(reg.get("foo").is_none());
947
948        reg.register(
949            "foo".to_string(),
950            PlanSourceCapabilities {
951                pushdown_mode: PlanPushdownMode::Full,
952                ..Default::default()
953            },
954        );
955        assert_eq!(
956            reg.get("foo").unwrap().pushdown_mode,
957            PlanPushdownMode::Full
958        );
959    }
960
961    #[test]
962    fn test_plan_source_capabilities_default() {
963        let caps = PlanSourceCapabilities::default();
964        assert_eq!(caps.pushdown_mode, PlanPushdownMode::None);
965        assert!(caps.eq_columns.is_empty());
966        assert!(!caps.supports_null_check);
967    }
968}