Skip to main content

laminar_sql/datafusion/
lookup_join.rs

1//! `LookupJoinNode` — custom DataFusion logical plan node for lookup joins.
2//!
3//! This node represents a join between a streaming input and a registered
4//! lookup table. It is produced by the `LookupJoinRewriteRule` optimizer
5//! rule when a standard JOIN references a registered lookup table.
6
7use std::collections::HashSet;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11
12use datafusion::common::DFSchemaRef;
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Expr, UserDefinedLogicalNodeCore};
15use datafusion_common::Result;
16
/// How a lookup join treats stream rows that have no matching lookup row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum LookupJoinType {
    /// Inner join: a stream row is emitted only when a matching lookup row exists.
    Inner,
    /// Left outer join: every stream row is emitted; lookup columns are NULL
    /// when no match exists.
    LeftOuter,
}

impl fmt::Display for LookupJoinType {
    /// Writes the bare variant name (`Inner` or `LeftOuter`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Inner => "Inner",
            Self::LeftOuter => "LeftOuter",
        };
        f.write_str(label)
    }
}
34
35/// A pair of expressions defining how stream keys map to lookup columns.
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub struct JoinKeyPair {
38    /// Expression on the stream side (e.g., `stream.customer_id`).
39    pub stream_expr: Expr,
40    /// Column name on the lookup table side.
41    pub lookup_column: String,
42}
43
/// Descriptive information about a registered lookup table, consumed while
/// building the lookup join plan.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare fields
/// top to bottom.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LookupTableMetadata {
    /// Name of the source connector, e.g. `"postgres-cdc"`.
    pub connector: String,
    /// Name of the lookup strategy, e.g. `"replicated"`.
    pub strategy: String,
    /// Name of the predicate pushdown mode, e.g. `"auto"`.
    pub pushdown_mode: String,
    /// Names of the table's primary-key columns.
    pub primary_key: Vec<String>,
}
56
57/// Custom logical plan node for a lookup join.
58///
59/// Represents a join between a streaming input plan and a lookup table.
60/// The lookup table is not a DataFusion table; it is resolved at execution
61/// time via the lookup source connector.
62#[derive(Debug, Clone)]
63pub struct LookupJoinNode {
64    /// The streaming input plan.
65    input: Arc<LogicalPlan>,
66    /// Name of the lookup table.
67    lookup_table: String,
68    /// Schema of the lookup table columns.
69    lookup_schema: DFSchemaRef,
70    /// Join key pairs (stream expression -> lookup column).
71    join_keys: Vec<JoinKeyPair>,
72    /// Join type (Inner or LeftOuter).
73    join_type: LookupJoinType,
74    /// Predicates to push down to the lookup source.
75    pushdown_predicates: Vec<Expr>,
76    /// Predicates evaluated locally after the join.
77    local_predicates: Vec<Expr>,
78    /// Required columns from the lookup table.
79    required_lookup_columns: HashSet<String>,
80    /// Combined output schema (stream + lookup columns).
81    output_schema: DFSchemaRef,
82    /// Metadata about the lookup table.
83    metadata: LookupTableMetadata,
84    /// Alias for the lookup table (for qualified column resolution).
85    lookup_alias: Option<String>,
86    /// Alias for the stream input (for qualified column resolution).
87    stream_alias: Option<String>,
88}
89
90impl PartialEq for LookupJoinNode {
91    fn eq(&self, other: &Self) -> bool {
92        self.lookup_table == other.lookup_table
93            && self.join_keys == other.join_keys
94            && self.join_type == other.join_type
95            && self.pushdown_predicates == other.pushdown_predicates
96            && self.local_predicates == other.local_predicates
97            && self.required_lookup_columns == other.required_lookup_columns
98            && self.metadata == other.metadata
99    }
100}
101
102impl Eq for LookupJoinNode {}
103
104impl Hash for LookupJoinNode {
105    fn hash<H: Hasher>(&self, state: &mut H) {
106        self.lookup_table.hash(state);
107        self.join_keys.hash(state);
108        self.join_type.hash(state);
109        self.pushdown_predicates.hash(state);
110        self.local_predicates.hash(state);
111        self.metadata.hash(state);
112        // HashSet doesn't implement Hash; hash sorted elements instead
113        let mut cols: Vec<&String> = self.required_lookup_columns.iter().collect();
114        cols.sort();
115        cols.hash(state);
116    }
117}
118
119impl PartialOrd for LookupJoinNode {
120    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
121        self.lookup_table.partial_cmp(&other.lookup_table)
122    }
123}
124
125impl LookupJoinNode {
126    /// Creates a new lookup join node.
127    #[must_use]
128    #[allow(clippy::too_many_arguments)]
129    pub fn new(
130        input: LogicalPlan,
131        lookup_table: String,
132        lookup_schema: DFSchemaRef,
133        join_keys: Vec<JoinKeyPair>,
134        join_type: LookupJoinType,
135        pushdown_predicates: Vec<Expr>,
136        required_lookup_columns: HashSet<String>,
137        output_schema: DFSchemaRef,
138        metadata: LookupTableMetadata,
139    ) -> Self {
140        Self {
141            input: Arc::new(input),
142            lookup_table,
143            lookup_schema,
144            join_keys,
145            join_type,
146            pushdown_predicates,
147            local_predicates: vec![],
148            required_lookup_columns,
149            output_schema,
150            metadata,
151            lookup_alias: None,
152            stream_alias: None,
153        }
154    }
155
156    /// Sets predicates to be evaluated locally after the join.
157    #[must_use]
158    pub fn with_local_predicates(mut self, predicates: Vec<Expr>) -> Self {
159        self.local_predicates = predicates;
160        self
161    }
162
163    /// Sets table aliases for qualified column resolution.
164    #[must_use]
165    pub fn with_aliases(
166        mut self,
167        lookup_alias: Option<String>,
168        stream_alias: Option<String>,
169    ) -> Self {
170        self.lookup_alias = lookup_alias;
171        self.stream_alias = stream_alias;
172        self
173    }
174
175    /// Returns the lookup table name.
176    #[must_use]
177    pub fn lookup_table_name(&self) -> &str {
178        &self.lookup_table
179    }
180
181    /// Returns the join key pairs.
182    #[must_use]
183    pub fn join_keys(&self) -> &[JoinKeyPair] {
184        &self.join_keys
185    }
186
187    /// Returns the join type.
188    #[must_use]
189    pub fn join_type(&self) -> LookupJoinType {
190        self.join_type
191    }
192
193    /// Returns the pushdown predicates.
194    #[must_use]
195    pub fn pushdown_predicates(&self) -> &[Expr] {
196        &self.pushdown_predicates
197    }
198
199    /// Returns the required lookup columns.
200    #[must_use]
201    pub fn required_lookup_columns(&self) -> &HashSet<String> {
202        &self.required_lookup_columns
203    }
204
205    /// Returns the lookup table metadata.
206    #[must_use]
207    pub fn metadata(&self) -> &LookupTableMetadata {
208        &self.metadata
209    }
210
211    /// Returns the lookup table schema.
212    #[must_use]
213    pub fn lookup_schema(&self) -> &DFSchemaRef {
214        &self.lookup_schema
215    }
216
217    /// Returns the local predicates (evaluated after the join).
218    #[must_use]
219    pub fn local_predicates(&self) -> &[Expr] {
220        &self.local_predicates
221    }
222
223    /// Returns the lookup table alias.
224    #[must_use]
225    pub fn lookup_alias(&self) -> Option<&str> {
226        self.lookup_alias.as_deref()
227    }
228
229    /// Returns the stream input alias.
230    #[must_use]
231    pub fn stream_alias(&self) -> Option<&str> {
232        self.stream_alias.as_deref()
233    }
234}
235
236impl UserDefinedLogicalNodeCore for LookupJoinNode {
237    fn name(&self) -> &'static str {
238        "LookupJoin"
239    }
240
241    fn inputs(&self) -> Vec<&LogicalPlan> {
242        vec![&self.input]
243    }
244
245    fn schema(&self) -> &DFSchemaRef {
246        &self.output_schema
247    }
248
249    fn expressions(&self) -> Vec<Expr> {
250        self.join_keys
251            .iter()
252            .map(|k| k.stream_expr.clone())
253            .chain(self.pushdown_predicates.clone())
254            .chain(self.local_predicates.clone())
255            .collect()
256    }
257
258    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
259        let keys: Vec<String> = self
260            .join_keys
261            .iter()
262            .map(|k| format!("{}={}", k.stream_expr, k.lookup_column))
263            .collect();
264        write!(
265            f,
266            "LookupJoin: table={}, keys=[{}], type={}, pushdown={}, local={}",
267            self.lookup_table,
268            keys.join(", "),
269            self.join_type,
270            self.pushdown_predicates.len(),
271            self.local_predicates.len(),
272        )
273    }
274
275    fn with_exprs_and_inputs(
276        &self,
277        exprs: Vec<Expr>,
278        mut inputs: Vec<LogicalPlan>,
279    ) -> Result<Self> {
280        let input = inputs.swap_remove(0);
281
282        // Split expressions: keys | pushdown predicates | local predicates
283        let num_keys = self.join_keys.len();
284        let num_pushdown = self.pushdown_predicates.len();
285        let (key_exprs, rest) = exprs.split_at(num_keys.min(exprs.len()));
286        let (pushdown_exprs, local_exprs) = rest.split_at(num_pushdown.min(rest.len()));
287
288        let join_keys: Vec<JoinKeyPair> = key_exprs
289            .iter()
290            .zip(self.join_keys.iter())
291            .map(|(expr, old)| JoinKeyPair {
292                stream_expr: expr.clone(),
293                lookup_column: old.lookup_column.clone(),
294            })
295            .collect();
296
297        Ok(Self {
298            input: Arc::new(input),
299            lookup_table: self.lookup_table.clone(),
300            lookup_schema: Arc::clone(&self.lookup_schema),
301            join_keys,
302            join_type: self.join_type,
303            pushdown_predicates: pushdown_exprs.to_vec(),
304            local_predicates: local_exprs.to_vec(),
305            required_lookup_columns: self.required_lookup_columns.clone(),
306            output_schema: Arc::clone(&self.output_schema),
307            metadata: self.metadata.clone(),
308            lookup_alias: self.lookup_alias.clone(),
309            stream_alias: self.stream_alias.clone(),
310        })
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::fmt::Write;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::{col, lit};

    fn test_stream_schema() -> DFSchemaRef {
        Arc::new(
            DFSchema::try_from(Schema::new(vec![
                Field::new("order_id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("amount", DataType::Float64, false),
            ]))
            .unwrap(),
        )
    }

    fn test_lookup_schema() -> DFSchemaRef {
        Arc::new(
            DFSchema::try_from(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
                Field::new("region", DataType::Utf8, true),
            ]))
            .unwrap(),
        )
    }

    fn test_output_schema() -> DFSchemaRef {
        Arc::new(
            DFSchema::try_from(Schema::new(vec![
                Field::new("order_id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("amount", DataType::Float64, false),
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
                Field::new("region", DataType::Utf8, true),
            ]))
            .unwrap(),
        )
    }

    fn test_metadata() -> LookupTableMetadata {
        LookupTableMetadata {
            connector: "postgres-cdc".to_string(),
            strategy: "replicated".to_string(),
            pushdown_mode: "auto".to_string(),
            primary_key: vec!["id".to_string()],
        }
    }

    /// Empty relation carrying the stream schema; stands in for the streaming input.
    fn test_input() -> LogicalPlan {
        LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
            produce_one_row: false,
            schema: test_stream_schema(),
        })
    }

    /// Single join key: stream `customer_id` against lookup column `id`.
    fn customer_join_key() -> Vec<JoinKeyPair> {
        vec![JoinKeyPair {
            stream_expr: col("customer_id"),
            lookup_column: "id".to_string(),
        }]
    }

    /// Builds a `customers` lookup join with the given join type, pushdown
    /// predicates and required lookup columns.
    fn build_node(
        join_type: LookupJoinType,
        pushdown_predicates: Vec<Expr>,
        required_lookup_columns: HashSet<String>,
    ) -> LookupJoinNode {
        LookupJoinNode::new(
            test_input(),
            "customers".to_string(),
            test_lookup_schema(),
            customer_join_key(),
            join_type,
            pushdown_predicates,
            required_lookup_columns,
            test_output_schema(),
            test_metadata(),
        )
    }

    fn test_node() -> LookupJoinNode {
        build_node(
            LookupJoinType::Inner,
            vec![],
            HashSet::from(["name".to_string(), "region".to_string()]),
        )
    }

    fn hash_of(node: &LookupJoinNode) -> u64 {
        let mut hasher = DefaultHasher::new();
        node.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_name() {
        let node = test_node();
        assert_eq!(node.name(), "LookupJoin");
    }

    #[test]
    fn test_inputs() {
        let node = test_node();
        assert_eq!(node.inputs().len(), 1);
    }

    #[test]
    fn test_schema() {
        let node = test_node();
        assert_eq!(node.schema().fields().len(), 6);
    }

    #[test]
    fn test_expressions() {
        let node = test_node();
        let exprs = node.expressions();
        assert_eq!(exprs.len(), 1); // one join key, no pushdown predicates
    }

    #[test]
    fn test_expressions_order() {
        let pushdown = col("region").eq(lit("EU"));
        let local = col("amount").gt(lit(100.0));
        let node = build_node(LookupJoinType::Inner, vec![pushdown.clone()], HashSet::new())
            .with_local_predicates(vec![local.clone()]);
        // keys first, then pushdown predicates, then local predicates
        assert_eq!(node.expressions(), vec![col("customer_id"), pushdown, local]);
    }

    #[test]
    fn test_accessors() {
        let node = test_node().with_aliases(Some("c".to_string()), Some("o".to_string()));
        assert_eq!(node.lookup_table_name(), "customers");
        assert_eq!(node.join_keys().len(), 1);
        assert_eq!(node.join_keys()[0].lookup_column, "id");
        assert_eq!(node.join_type(), LookupJoinType::Inner);
        assert!(node.pushdown_predicates().is_empty());
        assert!(node.local_predicates().is_empty());
        assert_eq!(node.required_lookup_columns().len(), 2);
        assert_eq!(node.metadata().connector, "postgres-cdc");
        assert_eq!(node.lookup_schema().fields().len(), 3);
        assert_eq!(node.lookup_alias(), Some("c"));
        assert_eq!(node.stream_alias(), Some("o"));
    }

    #[test]
    fn test_fmt_for_explain() {
        let node = test_node();
        let explain = format!("{node:?}");
        assert!(explain.contains("LookupJoin"));

        // Test the Display-like explain output
        let mut buf = String::new();
        write!(buf, "{}", DisplayExplain(&node)).unwrap();
        assert!(buf.contains("LookupJoin: table=customers"));
        assert!(buf.contains("keys=[customer_id=id]"));
        assert!(buf.contains("type=Inner"));
        assert!(buf.contains("pushdown=0, local=0"));
    }

    #[test]
    fn test_with_exprs_and_inputs_roundtrip() {
        let node = test_node();
        let exprs = node.expressions();
        let inputs: Vec<LogicalPlan> = node.inputs().into_iter().cloned().collect();

        let rebuilt = node.with_exprs_and_inputs(exprs, inputs).unwrap();
        assert_eq!(rebuilt.lookup_table, "customers");
        assert_eq!(rebuilt.join_keys.len(), 1);
        assert_eq!(rebuilt.join_type, LookupJoinType::Inner);
        assert_eq!(rebuilt, node);
    }

    #[test]
    fn test_with_exprs_and_inputs_preserves_predicates_and_aliases() {
        let node = build_node(
            LookupJoinType::LeftOuter,
            vec![col("region").eq(lit("EU"))],
            HashSet::from(["name".to_string()]),
        )
        .with_local_predicates(vec![col("amount").gt(lit(100.0))])
        .with_aliases(Some("c".to_string()), Some("o".to_string()));

        let exprs = node.expressions();
        assert_eq!(exprs.len(), 3);
        let inputs: Vec<LogicalPlan> = node.inputs().into_iter().cloned().collect();

        let rebuilt = node.with_exprs_and_inputs(exprs, inputs).unwrap();
        assert_eq!(rebuilt.join_keys(), node.join_keys());
        assert_eq!(rebuilt.pushdown_predicates(), node.pushdown_predicates());
        assert_eq!(rebuilt.local_predicates(), node.local_predicates());
        assert_eq!(rebuilt.join_type(), LookupJoinType::LeftOuter);
        assert_eq!(rebuilt.lookup_alias(), Some("c"));
        assert_eq!(rebuilt.stream_alias(), Some("o"));
        assert_eq!(rebuilt, node);
    }

    #[test]
    fn test_eq_and_hash_are_consistent() {
        let a = test_node();
        let b = test_node();
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
        assert_eq!(a.partial_cmp(&b), Some(std::cmp::Ordering::Equal));

        let left = build_node(
            LookupJoinType::LeftOuter,
            vec![],
            HashSet::from(["name".to_string(), "region".to_string()]),
        );
        assert_ne!(a, left);
    }

    #[test]
    fn test_left_outer_join() {
        let node = build_node(LookupJoinType::LeftOuter, vec![], HashSet::new());
        assert_eq!(node.join_type(), LookupJoinType::LeftOuter);
    }

    /// Helper to test `fmt_for_explain` through the trait method.
    struct DisplayExplain<'a>(&'a LookupJoinNode);

    impl fmt::Display for DisplayExplain<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            UserDefinedLogicalNodeCore::fmt_for_explain(self.0, f)
        }
    }
}