// laminar_sql/planner/predicate_split.rs

//! Predicate splitting and pushdown for lookup joins.
//!
//! This module classifies WHERE/ON predicates in lookup join queries and
//! splits them into pushdown vs. local evaluation categories. It implements
//! a DataFusion optimizer rule (`PredicateSplitterRule`) that absorbs
//! filter nodes above `LookupJoinNode` and assigns each predicate to the
//! correct execution site.
//!
//! ## Key Safety Rules
//!
//! - **H10 (LEFT JOIN safety):** WHERE-clause predicates on lookup-only
//!   columns above a `LeftOuter` join must NOT be pushed down — doing so
//!   changes the semantics by filtering out NULL-extended rows.
//! - **C7 (qualified columns):** When aliases are present, `col.relation`
//!   is checked first for unambiguous resolution before falling back to
//!   unqualified column name matching.
//! - **`NotEq`** predicates are classified normally but are never pushed
//!   down (they cannot use equality indexes on the source).
19
20#[allow(clippy::disallowed_types)] // cold path: query planning
21use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23
24use datafusion::logical_expr::logical_plan::LogicalPlan;
25use datafusion::logical_expr::{
26    BinaryExpr, Expr, Extension, Filter, Operator as DfOperator, UserDefinedLogicalNodeCore,
27};
28use datafusion_common::tree_node::Transformed;
29use datafusion_common::Result;
30use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
31
32use crate::datafusion::lookup_join::{LookupJoinNode, LookupJoinType};
33
// ---------------------------------------------------------------------------
// Predicate Classification
// ---------------------------------------------------------------------------
37
/// Classification of a predicate based on which side(s) of the lookup join
/// its column references belong to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PredicateClass {
    /// References only lookup table columns — candidate for pushdown.
    LookupOnly,
    /// References only stream columns — evaluate locally.
    StreamOnly,
    /// References columns from both sides — evaluate locally.
    CrossReference,
    /// References no columns (constant expression) — evaluate locally.
    Constant,
}
50
51/// Classifies predicates based on column membership.
52///
53/// Uses both unqualified column names and qualified `"alias.col"` names
54/// for resolution (audit C7). When a column has a relation qualifier,
55/// the qualified form is checked first.
56#[derive(Debug)]
57pub struct PredicateClassifier {
58    /// Unqualified lookup column names.
59    lookup_columns: HashSet<String>,
60    /// Unqualified stream column names.
61    stream_columns: HashSet<String>,
62    /// Qualified lookup names: `"alias.col"` or `"table.col"`.
63    lookup_qualified: HashSet<String>,
64    /// Qualified stream names: `"alias.col"` or `"table.col"`.
65    stream_qualified: HashSet<String>,
66}
67
68impl PredicateClassifier {
69    /// Creates a new classifier from column sets.
70    ///
71    /// `lookup_alias` / `stream_alias` are the SQL aliases (e.g., `c` for
72    /// `customers c`). When provided, qualified lookups like `c.name` can
73    /// be resolved unambiguously.
74    #[must_use]
75    pub fn new(
76        lookup_columns: HashSet<String>,
77        stream_columns: HashSet<String>,
78        lookup_alias: Option<&str>,
79        stream_alias: Option<&str>,
80    ) -> Self {
81        let mut lookup_qualified = HashSet::new();
82        let mut stream_qualified = HashSet::new();
83
84        if let Some(alias) = lookup_alias {
85            for col in &lookup_columns {
86                lookup_qualified.insert(format!("{alias}.{col}"));
87            }
88        }
89        if let Some(alias) = stream_alias {
90            for col in &stream_columns {
91                stream_qualified.insert(format!("{alias}.{col}"));
92            }
93        }
94
95        Self {
96            lookup_columns,
97            stream_columns,
98            lookup_qualified,
99            stream_qualified,
100        }
101    }
102
103    /// Classify a predicate expression.
104    #[must_use]
105    pub fn classify(&self, expr: &Expr) -> PredicateClass {
106        let mut has_lookup = false;
107        let mut has_stream = false;
108        self.walk_columns(expr, &mut has_lookup, &mut has_stream);
109
110        match (has_lookup, has_stream) {
111            (true, false) => PredicateClass::LookupOnly,
112            (false, true) => PredicateClass::StreamOnly,
113            (true, true) => PredicateClass::CrossReference,
114            (false, false) => PredicateClass::Constant,
115        }
116    }
117
118    /// Recursively walk an expression to find column references.
119    fn walk_columns(&self, expr: &Expr, has_lookup: &mut bool, has_stream: &mut bool) {
120        match expr {
121            Expr::Column(col) => {
122                // C7: check qualified form first
123                if let Some(relation) = &col.relation {
124                    let qualified = format!("{}.{}", relation, col.name);
125                    if self.lookup_qualified.contains(&qualified) {
126                        *has_lookup = true;
127                        return;
128                    }
129                    if self.stream_qualified.contains(&qualified) {
130                        *has_stream = true;
131                        return;
132                    }
133                }
134                // Fall back to unqualified
135                if self.lookup_columns.contains(&col.name) {
136                    *has_lookup = true;
137                }
138                if self.stream_columns.contains(&col.name) {
139                    *has_stream = true;
140                }
141            }
142            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
143                self.walk_columns(left, has_lookup, has_stream);
144                self.walk_columns(right, has_lookup, has_stream);
145            }
146            Expr::Not(inner)
147            | Expr::IsNull(inner)
148            | Expr::IsNotNull(inner)
149            | Expr::Negative(inner)
150            | Expr::Cast(datafusion::logical_expr::Cast { expr: inner, .. })
151            | Expr::TryCast(datafusion::logical_expr::TryCast { expr: inner, .. }) => {
152                self.walk_columns(inner, has_lookup, has_stream);
153            }
154            Expr::Between(between) => {
155                self.walk_columns(&between.expr, has_lookup, has_stream);
156                self.walk_columns(&between.low, has_lookup, has_stream);
157                self.walk_columns(&between.high, has_lookup, has_stream);
158            }
159            Expr::InList(in_list) => {
160                self.walk_columns(&in_list.expr, has_lookup, has_stream);
161                for item in &in_list.list {
162                    self.walk_columns(item, has_lookup, has_stream);
163                }
164            }
165            Expr::ScalarFunction(func) => {
166                for arg in &func.args {
167                    self.walk_columns(arg, has_lookup, has_stream);
168                }
169            }
170            Expr::Like(like) => {
171                self.walk_columns(&like.expr, has_lookup, has_stream);
172                self.walk_columns(&like.pattern, has_lookup, has_stream);
173            }
174            Expr::Case(case) => {
175                if let Some(operand) = &case.expr {
176                    self.walk_columns(operand, has_lookup, has_stream);
177                }
178                for (when, then) in &case.when_then_expr {
179                    self.walk_columns(when, has_lookup, has_stream);
180                    self.walk_columns(then, has_lookup, has_stream);
181                }
182                if let Some(else_expr) = &case.else_expr {
183                    self.walk_columns(else_expr, has_lookup, has_stream);
184                }
185            }
186            // Literals, placeholders — no columns
187            Expr::Literal(..) | Expr::Placeholder(_) => {}
188            // Catch-all: conservative — mark both sides
189            _ => {
190                *has_lookup = true;
191                *has_stream = true;
192            }
193        }
194    }
195}
196
// ---------------------------------------------------------------------------
// Source Capabilities
// ---------------------------------------------------------------------------
200
/// Mode describing how far predicates can be pushed to a source.
///
/// The default is [`PlanPushdownMode::None`], i.e. pushdown is opt-in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PlanPushdownMode {
    /// Full predicate pushdown (eq, range, in, null checks).
    Full,
    /// Only key equality predicates.
    KeyOnly,
    /// No pushdown at all.
    #[default]
    None,
}
211
212/// Describes a source's pushdown capabilities for the optimizer.
213#[derive(Debug, Clone)]
214pub struct PlanSourceCapabilities {
215    /// Overall pushdown mode.
216    pub pushdown_mode: PlanPushdownMode,
217    /// Columns that support equality pushdown.
218    pub eq_columns: HashSet<String>,
219    /// Columns that support range pushdown.
220    pub range_columns: HashSet<String>,
221    /// Columns that support IN-list pushdown.
222    pub in_columns: HashSet<String>,
223    /// Whether the source supports IS NULL / IS NOT NULL checks.
224    pub supports_null_check: bool,
225}
226
227impl Default for PlanSourceCapabilities {
228    fn default() -> Self {
229        Self {
230            pushdown_mode: PlanPushdownMode::None,
231            eq_columns: HashSet::new(),
232            range_columns: HashSet::new(),
233            in_columns: HashSet::new(),
234            supports_null_check: false,
235        }
236    }
237}
238
239/// Registry mapping lookup table names to their source capabilities.
240#[derive(Debug, Default)]
241pub struct SourceCapabilitiesRegistry {
242    capabilities: HashMap<String, PlanSourceCapabilities>,
243}
244
245impl SourceCapabilitiesRegistry {
246    /// Register capabilities for a lookup table.
247    pub fn register(&mut self, table_name: String, caps: PlanSourceCapabilities) {
248        self.capabilities.insert(table_name, caps);
249    }
250
251    /// Get capabilities for a lookup table.
252    #[must_use]
253    pub fn get(&self, table_name: &str) -> Option<&PlanSourceCapabilities> {
254        self.capabilities.get(table_name)
255    }
256}
257
// ---------------------------------------------------------------------------
// Conjunction Splitting
// ---------------------------------------------------------------------------
261
262/// Splits a conjunction (AND chain) into individual predicates.
263///
264/// `A AND B AND C` → `[A, B, C]`.
265/// OR expressions and non-AND binary expressions are kept as single items.
266#[must_use]
267pub fn split_conjunction(expr: &Expr) -> Vec<Expr> {
268    match expr {
269        Expr::BinaryExpr(BinaryExpr {
270            left,
271            op: DfOperator::And,
272            right,
273        }) => {
274            let mut parts = split_conjunction(left);
275            parts.extend(split_conjunction(right));
276            parts
277        }
278        other => vec![other.clone()],
279    }
280}
281
// ---------------------------------------------------------------------------
// Optimizer Rule
// ---------------------------------------------------------------------------
285
286/// DataFusion optimizer rule that splits predicates for lookup joins.
287///
288/// Runs `TopDown` to catch `Filter` nodes above `LookupJoinNode` first.
289///
290/// Two cases:
291/// 1. **Filter above LookupJoinNode** — absorb the filter, classify
292///    each conjunct, and assign to pushdown or local.
293/// 2. **Direct LookupJoinNode** — re-classify existing pushdown predicates
294///    (e.g., after a previous pass added them).
295#[derive(Debug)]
296pub struct PredicateSplitterRule {
297    /// Per-table source capabilities.
298    capabilities: SourceCapabilitiesRegistry,
299}
300
301impl PredicateSplitterRule {
302    /// Creates a new rule with the given capabilities registry.
303    #[must_use]
304    pub fn new(capabilities: SourceCapabilitiesRegistry) -> Self {
305        Self { capabilities }
306    }
307
308    /// Split predicates for a `LookupJoinNode`, given a list of predicates
309    /// that come from an absorbed `Filter` (if any) plus the node's
310    /// existing predicates.
311    fn split_for_node(
312        &self,
313        node: &LookupJoinNode,
314        filter_predicates: &[Expr],
315    ) -> (Vec<Expr>, Vec<Expr>) {
316        // Build column sets from schemas
317        let lookup_columns: HashSet<String> = node
318            .lookup_schema()
319            .fields()
320            .iter()
321            .map(|f| f.name().clone())
322            .collect();
323
324        let input_schema = node.inputs()[0].schema();
325        let stream_columns: HashSet<String> = input_schema
326            .fields()
327            .iter()
328            .map(|f| f.name().clone())
329            .collect();
330
331        let classifier = PredicateClassifier::new(
332            lookup_columns,
333            stream_columns,
334            node.lookup_alias(),
335            node.stream_alias(),
336        );
337
338        let caps = self.capabilities.get(node.lookup_table_name());
339        let pushdown_disabled = caps.is_none_or(|c| c.pushdown_mode == PlanPushdownMode::None);
340
341        let is_left_outer = node.join_type() == LookupJoinType::LeftOuter;
342
343        let mut pushdown = Vec::new();
344        let mut local = Vec::new();
345
346        // Include existing predicates from the node
347        let all_predicates = node
348            .pushdown_predicates()
349            .iter()
350            .chain(node.local_predicates().iter())
351            .chain(filter_predicates.iter())
352            .cloned();
353
354        for pred in all_predicates {
355            let class = classifier.classify(&pred);
356
357            // NotEq predicates never push down
358            let has_not_eq = contains_not_eq(&pred);
359
360            match class {
361                PredicateClass::LookupOnly => {
362                    // H10: LEFT OUTER WHERE-clause lookup-only preds stay local
363                    if is_left_outer || pushdown_disabled || has_not_eq {
364                        local.push(pred);
365                    } else {
366                        pushdown.push(pred);
367                    }
368                }
369                PredicateClass::StreamOnly
370                | PredicateClass::CrossReference
371                | PredicateClass::Constant => {
372                    local.push(pred);
373                }
374            }
375        }
376
377        (pushdown, local)
378    }
379}
380
381/// Check if an expression contains a `NotEq` operator.
382fn contains_not_eq(expr: &Expr) -> bool {
383    match expr {
384        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
385            *op == DfOperator::NotEq || contains_not_eq(left) || contains_not_eq(right)
386        }
387        Expr::Not(inner) => contains_not_eq(inner),
388        _ => false,
389    }
390}
391
392impl OptimizerRule for PredicateSplitterRule {
393    fn name(&self) -> &'static str {
394        "predicate_splitter"
395    }
396
397    fn apply_order(&self) -> Option<ApplyOrder> {
398        Some(ApplyOrder::TopDown)
399    }
400
401    fn rewrite(
402        &self,
403        plan: LogicalPlan,
404        _config: &dyn OptimizerConfig,
405    ) -> Result<Transformed<LogicalPlan>> {
406        // Case 1: Filter above a LookupJoinNode
407        if let LogicalPlan::Filter(Filter {
408            predicate, input, ..
409        }) = &plan
410        {
411            if let LogicalPlan::Extension(ext) = input.as_ref() {
412                if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
413                    let filter_preds = split_conjunction(predicate);
414                    let (pushdown, local) = self.split_for_node(node, &filter_preds);
415
416                    let inputs = node.inputs();
417                    let rebuilt = LookupJoinNode::new(
418                        inputs[0].clone(),
419                        node.lookup_table_name().to_string(),
420                        node.lookup_schema().clone(),
421                        node.join_keys().to_vec(),
422                        node.join_type(),
423                        pushdown,
424                        node.required_lookup_columns().clone(),
425                        UserDefinedLogicalNodeCore::schema(node).clone(),
426                        node.metadata().clone(),
427                    )
428                    .with_local_predicates(local)
429                    .with_aliases(
430                        node.lookup_alias().map(String::from),
431                        node.stream_alias().map(String::from),
432                    );
433
434                    return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
435                        node: Arc::new(rebuilt),
436                    })));
437                }
438            }
439        }
440
441        // Case 2: Direct LookupJoinNode (re-classify existing predicates)
442        if let LogicalPlan::Extension(ext) = &plan {
443            if let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() {
444                // Only re-classify if there are predicates to work with
445                if !node.pushdown_predicates().is_empty() || !node.local_predicates().is_empty() {
446                    let (pushdown, local) = self.split_for_node(node, &[]);
447                    let inputs = node.inputs();
448                    let rebuilt = LookupJoinNode::new(
449                        inputs[0].clone(),
450                        node.lookup_table_name().to_string(),
451                        node.lookup_schema().clone(),
452                        node.join_keys().to_vec(),
453                        node.join_type(),
454                        pushdown,
455                        node.required_lookup_columns().clone(),
456                        UserDefinedLogicalNodeCore::schema(node).clone(),
457                        node.metadata().clone(),
458                    )
459                    .with_local_predicates(local)
460                    .with_aliases(
461                        node.lookup_alias().map(String::from),
462                        node.stream_alias().map(String::from),
463                    );
464
465                    return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
466                        node: Arc::new(rebuilt),
467                    })));
468                }
469            }
470        }
471
472        Ok(Transformed::no(plan))
473    }
474}
475
476#[cfg(test)]
477mod tests {
478    use super::*;
479
480    #[allow(clippy::disallowed_types)] // cold path: query planning
481    use std::collections::HashSet;
482
483    use arrow::datatypes::{DataType, Field, Schema};
484    use datafusion::common::DFSchema;
485    use datafusion::logical_expr::col;
486    use datafusion::prelude::lit;
487
488    use crate::datafusion::lookup_join::{
489        JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
490    };
491
492    fn lookup_cols() -> HashSet<String> {
493        HashSet::from(["id".to_string(), "name".to_string(), "region".to_string()])
494    }
495
496    fn stream_cols() -> HashSet<String> {
497        HashSet::from([
498            "order_id".to_string(),
499            "customer_id".to_string(),
500            "amount".to_string(),
501        ])
502    }
503
504    fn classifier() -> PredicateClassifier {
505        PredicateClassifier::new(lookup_cols(), stream_cols(), None, None)
506    }
507
508    fn classifier_with_aliases() -> PredicateClassifier {
509        PredicateClassifier::new(lookup_cols(), stream_cols(), Some("c"), Some("o"))
510    }
511
    // -----------------------------------------------------------------------
    // PredicateClassifier tests
    // -----------------------------------------------------------------------
515
516    #[test]
517    fn test_classify_lookup_only() {
518        let c = classifier();
519        let expr = col("region").eq(lit("US"));
520        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
521    }
522
523    #[test]
524    fn test_classify_stream_only() {
525        let c = classifier();
526        let expr = col("amount").gt(lit(100));
527        assert_eq!(c.classify(&expr), PredicateClass::StreamOnly);
528    }
529
530    #[test]
531    fn test_classify_cross_reference() {
532        let c = classifier();
533        // amount (stream) > id (lookup) → cross-reference
534        let expr = col("amount").gt(col("id"));
535        assert_eq!(c.classify(&expr), PredicateClass::CrossReference);
536    }
537
538    #[test]
539    fn test_classify_constant() {
540        let c = classifier();
541        let expr = lit(1).eq(lit(1));
542        assert_eq!(c.classify(&expr), PredicateClass::Constant);
543    }
544
545    #[test]
546    fn test_classify_qualified_lookup_c7() {
547        let c = classifier_with_aliases();
548        // c.name should resolve to lookup via qualified match
549        let expr = Expr::Column(datafusion::common::Column::new(Some::<&str>("c"), "name"))
550            .eq(lit("Alice"));
551        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
552    }
553
554    #[test]
555    fn test_classify_qualified_stream_c7() {
556        let c = classifier_with_aliases();
557        let expr =
558            Expr::Column(datafusion::common::Column::new(Some::<&str>("o"), "amount")).gt(lit(50));
559        assert_eq!(c.classify(&expr), PredicateClass::StreamOnly);
560    }
561
562    #[test]
563    fn test_classify_ambiguous_both_sides() {
564        // Column name exists in both sides without qualifier → both flags set
565        let lookup = HashSet::from(["id".to_string()]);
566        let stream = HashSet::from(["id".to_string()]);
567        let c = PredicateClassifier::new(lookup, stream, None, None);
568        let expr = col("id").eq(lit(1));
569        assert_eq!(c.classify(&expr), PredicateClass::CrossReference);
570    }
571
572    #[test]
573    fn test_classify_nested_function() {
574        let c = classifier();
575        // UPPER(name) = 'ALICE' — name is lookup-only
576        let expr = Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
577            func: datafusion::functions::string::upper(),
578            args: vec![col("name")],
579        })
580        .eq(lit("ALICE"));
581        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
582    }
583
584    #[test]
585    fn test_classify_is_null() {
586        let c = classifier();
587        let expr = col("name").is_null();
588        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
589    }
590
591    #[test]
592    fn test_classify_between() {
593        let c = classifier();
594        let expr = Expr::Between(datafusion::logical_expr::expr::Between {
595            expr: Box::new(col("amount")),
596            negated: false,
597            low: Box::new(lit(10)),
598            high: Box::new(lit(100)),
599        });
600        assert_eq!(c.classify(&expr), PredicateClass::StreamOnly);
601    }
602
603    #[test]
604    fn test_classify_in_list() {
605        let c = classifier();
606        let expr = col("region").in_list(vec![lit("US"), lit("EU")], false);
607        assert_eq!(c.classify(&expr), PredicateClass::LookupOnly);
608    }
609
    // -----------------------------------------------------------------------
    // split_conjunction tests
    // -----------------------------------------------------------------------
613
614    #[test]
615    fn test_split_flat_conjunction() {
616        let expr = col("a")
617            .eq(lit(1))
618            .and(col("b").eq(lit(2)))
619            .and(col("c").eq(lit(3)));
620        let parts = split_conjunction(&expr);
621        assert_eq!(parts.len(), 3);
622    }
623
624    #[test]
625    fn test_split_nested_conjunction() {
626        // (A AND B) AND (C AND D)
627        let left = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
628        let right = col("c").eq(lit(3)).and(col("d").eq(lit(4)));
629        let expr = left.and(right);
630        let parts = split_conjunction(&expr);
631        assert_eq!(parts.len(), 4);
632    }
633
634    #[test]
635    fn test_split_single_predicate() {
636        let expr = col("a").eq(lit(1));
637        let parts = split_conjunction(&expr);
638        assert_eq!(parts.len(), 1);
639    }
640
641    #[test]
642    fn test_split_or_not_split() {
643        // OR should NOT be split
644        let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
645        let parts = split_conjunction(&expr);
646        assert_eq!(parts.len(), 1);
647    }
648
    // -----------------------------------------------------------------------
    // PredicateSplitterRule integration tests
    // -----------------------------------------------------------------------
652
653    fn test_metadata() -> LookupTableMetadata {
654        LookupTableMetadata {
655            connector: "postgres-cdc".to_string(),
656            strategy: "replicated".to_string(),
657            pushdown_mode: "auto".to_string(),
658            primary_key: vec!["id".to_string()],
659        }
660    }
661
662    fn test_stream_schema() -> Arc<DFSchema> {
663        Arc::new(
664            DFSchema::try_from(Schema::new(vec![
665                Field::new("order_id", DataType::Int64, false),
666                Field::new("customer_id", DataType::Int64, false),
667                Field::new("amount", DataType::Float64, false),
668            ]))
669            .unwrap(),
670        )
671    }
672
673    fn test_lookup_schema() -> Arc<DFSchema> {
674        Arc::new(
675            DFSchema::try_from(Schema::new(vec![
676                Field::new("id", DataType::Int64, false),
677                Field::new("name", DataType::Utf8, true),
678                Field::new("region", DataType::Utf8, true),
679            ]))
680            .unwrap(),
681        )
682    }
683
684    fn test_output_schema() -> Arc<DFSchema> {
685        Arc::new(
686            DFSchema::try_from(Schema::new(vec![
687                Field::new("order_id", DataType::Int64, false),
688                Field::new("customer_id", DataType::Int64, false),
689                Field::new("amount", DataType::Float64, false),
690                Field::new("id", DataType::Int64, false),
691                Field::new("name", DataType::Utf8, true),
692                Field::new("region", DataType::Utf8, true),
693            ]))
694            .unwrap(),
695        )
696    }
697
698    fn make_lookup_node(join_type: LookupJoinType) -> LookupJoinNode {
699        let stream_schema = test_stream_schema();
700        let input = LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
701            produce_one_row: false,
702            schema: stream_schema,
703        });
704
705        LookupJoinNode::new(
706            input,
707            "customers".to_string(),
708            test_lookup_schema(),
709            vec![JoinKeyPair {
710                stream_expr: col("customer_id"),
711                lookup_column: "id".to_string(),
712            }],
713            join_type,
714            vec![],
715            HashSet::from(["id".to_string(), "name".to_string(), "region".to_string()]),
716            test_output_schema(),
717            test_metadata(),
718        )
719    }
720
721    fn make_filter_over_node(node: LookupJoinNode, predicate: Expr) -> LogicalPlan {
722        let ext = LogicalPlan::Extension(Extension {
723            node: Arc::new(node),
724        });
725        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(ext)).unwrap())
726    }
727
728    fn full_capabilities() -> SourceCapabilitiesRegistry {
729        let mut reg = SourceCapabilitiesRegistry::default();
730        reg.register(
731            "customers".to_string(),
732            PlanSourceCapabilities {
733                pushdown_mode: PlanPushdownMode::Full,
734                eq_columns: HashSet::from([
735                    "id".to_string(),
736                    "name".to_string(),
737                    "region".to_string(),
738                ]),
739                range_columns: HashSet::new(),
740                in_columns: HashSet::new(),
741                supports_null_check: true,
742            },
743        );
744        reg
745    }
746
747    fn no_capabilities() -> SourceCapabilitiesRegistry {
748        SourceCapabilitiesRegistry::default()
749    }
750
751    #[test]
752    fn test_pushdown_inner_join_lookup_only() {
753        let node = make_lookup_node(LookupJoinType::Inner);
754        let filter_pred = col("region").eq(lit("US"));
755        let plan = make_filter_over_node(node, filter_pred);
756
757        let rule = PredicateSplitterRule::new(full_capabilities());
758        let result = rule
759            .rewrite(
760                plan,
761                &datafusion_optimizer::optimizer::OptimizerContext::new(),
762            )
763            .unwrap();
764
765        assert!(result.transformed);
766        if let LogicalPlan::Extension(ext) = &result.data {
767            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
768            assert_eq!(rebuilt.pushdown_predicates().len(), 1);
769            assert_eq!(rebuilt.local_predicates().len(), 0);
770        } else {
771            panic!("Expected Extension node");
772        }
773    }
774
775    #[test]
776    fn test_stream_predicate_stays_local() {
777        let node = make_lookup_node(LookupJoinType::Inner);
778        let filter_pred = col("amount").gt(lit(100));
779        let plan = make_filter_over_node(node, filter_pred);
780
781        let rule = PredicateSplitterRule::new(full_capabilities());
782        let result = rule
783            .rewrite(
784                plan,
785                &datafusion_optimizer::optimizer::OptimizerContext::new(),
786            )
787            .unwrap();
788
789        assert!(result.transformed);
790        if let LogicalPlan::Extension(ext) = &result.data {
791            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
792            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
793            assert_eq!(rebuilt.local_predicates().len(), 1);
794        } else {
795            panic!("Expected Extension node");
796        }
797    }
798
799    #[test]
800    fn test_cross_ref_stays_local() {
801        let node = make_lookup_node(LookupJoinType::Inner);
802        // amount > id  (crosses stream and lookup)
803        let filter_pred = col("amount").gt(col("id"));
804        let plan = make_filter_over_node(node, filter_pred);
805
806        let rule = PredicateSplitterRule::new(full_capabilities());
807        let result = rule
808            .rewrite(
809                plan,
810                &datafusion_optimizer::optimizer::OptimizerContext::new(),
811            )
812            .unwrap();
813
814        assert!(result.transformed);
815        if let LogicalPlan::Extension(ext) = &result.data {
816            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
817            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
818            assert_eq!(rebuilt.local_predicates().len(), 1);
819        } else {
820            panic!("Expected Extension node");
821        }
822    }
823
824    #[test]
825    fn test_pushdown_disabled_keeps_local() {
826        let node = make_lookup_node(LookupJoinType::Inner);
827        let filter_pred = col("region").eq(lit("US"));
828        let plan = make_filter_over_node(node, filter_pred);
829
830        // No capabilities registered → pushdown disabled
831        let rule = PredicateSplitterRule::new(no_capabilities());
832        let result = rule
833            .rewrite(
834                plan,
835                &datafusion_optimizer::optimizer::OptimizerContext::new(),
836            )
837            .unwrap();
838
839        assert!(result.transformed);
840        if let LogicalPlan::Extension(ext) = &result.data {
841            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
842            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
843            assert_eq!(rebuilt.local_predicates().len(), 1);
844        } else {
845            panic!("Expected Extension node");
846        }
847    }
848
849    #[test]
850    fn test_left_join_h10_safety() {
851        // H10: LEFT OUTER lookup-only preds must NOT be pushed down
852        let node = make_lookup_node(LookupJoinType::LeftOuter);
853        let filter_pred = col("region").eq(lit("US"));
854        let plan = make_filter_over_node(node, filter_pred);
855
856        let rule = PredicateSplitterRule::new(full_capabilities());
857        let result = rule
858            .rewrite(
859                plan,
860                &datafusion_optimizer::optimizer::OptimizerContext::new(),
861            )
862            .unwrap();
863
864        assert!(result.transformed);
865        if let LogicalPlan::Extension(ext) = &result.data {
866            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
867            // Should stay local due to H10
868            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
869            assert_eq!(rebuilt.local_predicates().len(), 1);
870        } else {
871            panic!("Expected Extension node");
872        }
873    }
874
875    #[test]
876    fn test_no_filter_no_predicates_passthrough() {
877        let node = make_lookup_node(LookupJoinType::Inner);
878        let plan = LogicalPlan::Extension(Extension {
879            node: Arc::new(node),
880        });
881
882        let rule = PredicateSplitterRule::new(full_capabilities());
883        let result = rule
884            .rewrite(
885                plan,
886                &datafusion_optimizer::optimizer::OptimizerContext::new(),
887            )
888            .unwrap();
889
890        // No predicates to split → no transformation
891        assert!(!result.transformed);
892    }
893
894    #[test]
895    fn test_mixed_conjunction_split() {
896        let node = make_lookup_node(LookupJoinType::Inner);
897        // region = 'US' AND amount > 100
898        let filter_pred = col("region").eq(lit("US")).and(col("amount").gt(lit(100)));
899        let plan = make_filter_over_node(node, filter_pred);
900
901        let rule = PredicateSplitterRule::new(full_capabilities());
902        let result = rule
903            .rewrite(
904                plan,
905                &datafusion_optimizer::optimizer::OptimizerContext::new(),
906            )
907            .unwrap();
908
909        assert!(result.transformed);
910        if let LogicalPlan::Extension(ext) = &result.data {
911            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
912            // region = 'US' → pushdown, amount > 100 → local
913            assert_eq!(rebuilt.pushdown_predicates().len(), 1);
914            assert_eq!(rebuilt.local_predicates().len(), 1);
915        } else {
916            panic!("Expected Extension node");
917        }
918    }
919
920    #[test]
921    fn test_not_eq_stays_local() {
922        let node = make_lookup_node(LookupJoinType::Inner);
923        let filter_pred = col("region").not_eq(lit("US"));
924        let plan = make_filter_over_node(node, filter_pred);
925
926        let rule = PredicateSplitterRule::new(full_capabilities());
927        let result = rule
928            .rewrite(
929                plan,
930                &datafusion_optimizer::optimizer::OptimizerContext::new(),
931            )
932            .unwrap();
933
934        assert!(result.transformed);
935        if let LogicalPlan::Extension(ext) = &result.data {
936            let rebuilt = ext.node.as_any().downcast_ref::<LookupJoinNode>().unwrap();
937            // NotEq never pushed down
938            assert_eq!(rebuilt.pushdown_predicates().len(), 0);
939            assert_eq!(rebuilt.local_predicates().len(), 1);
940        } else {
941            panic!("Expected Extension node");
942        }
943    }
944
945    #[test]
946    fn test_source_capabilities_registry() {
947        let mut reg = SourceCapabilitiesRegistry::default();
948        assert!(reg.get("foo").is_none());
949
950        reg.register(
951            "foo".to_string(),
952            PlanSourceCapabilities {
953                pushdown_mode: PlanPushdownMode::Full,
954                ..Default::default()
955            },
956        );
957        assert_eq!(
958            reg.get("foo").unwrap().pushdown_mode,
959            PlanPushdownMode::Full
960        );
961    }
962
963    #[test]
964    fn test_plan_source_capabilities_default() {
965        let caps = PlanSourceCapabilities::default();
966        assert_eq!(caps.pushdown_mode, PlanPushdownMode::None);
967        assert!(caps.eq_columns.is_empty());
968        assert!(!caps.supports_null_check);
969    }
970}