// laminar_sql/planner/lookup_join.rs
//! Optimizer rules that rewrite standard JOINs to `LookupJoinNode`.
//!
//! When a query joins a streaming source with a registered lookup table,
//! the `LookupJoinRewriteRule` replaces the standard hash/merge join
//! with a `LookupJoinNode` that uses the lookup source connector.

#[allow(clippy::disallowed_types)] // cold path: query planning
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::sync::Arc;

use datafusion::common::{DFSchema, Result};
use datafusion::logical_expr::logical_plan::LogicalPlan;
use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
use datafusion_common::tree_node::Transformed;
use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};

use crate::datafusion::lookup_join::{
    JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
};
use crate::planner::LookupTableInfo;
23/// Rewrites standard JOIN nodes that reference a lookup table into
24/// `LookupJoinNode` extension nodes.
25#[derive(Debug)]
26pub struct LookupJoinRewriteRule {
27    /// Registered lookup tables, keyed by name.
28    lookup_tables: HashMap<String, LookupTableInfo>,
29}
30
31impl LookupJoinRewriteRule {
32    /// Creates a new rewrite rule with the given set of registered lookup tables.
33    #[must_use]
34    pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35        Self { lookup_tables }
36    }
37
38    /// Detects which side of a join (if any) is a lookup table scan.
39    /// Returns `Some((lookup_side_is_right, table_name))`.
40    fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41        // Check right side
42        if let Some(name) = scan_table_name(&join.right) {
43            if self.lookup_tables.contains_key(&name) {
44                return Some((true, name));
45            }
46        }
47        // Check left side
48        if let Some(name) = scan_table_name(&join.left) {
49            if self.lookup_tables.contains_key(&name) {
50                return Some((false, name));
51            }
52        }
53        None
54    }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58    fn name(&self) -> &'static str {
59        "lookup_join_rewrite"
60    }
61
62    fn apply_order(&self) -> Option<ApplyOrder> {
63        Some(ApplyOrder::BottomUp)
64    }
65
66    fn rewrite(
67        &self,
68        plan: LogicalPlan,
69        _config: &dyn OptimizerConfig,
70    ) -> Result<Transformed<LogicalPlan>> {
71        let LogicalPlan::Join(join) = &plan else {
72            return Ok(Transformed::no(plan));
73        };
74
75        let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76            return Ok(Transformed::no(plan));
77        };
78
79        let info = &self.lookup_tables[&table_name];
80
81        // Determine which side is stream and which is lookup
82        let (stream_plan, lookup_plan) = if lookup_is_right {
83            (join.left.as_ref(), join.right.as_ref())
84        } else {
85            (join.right.as_ref(), join.left.as_ref())
86        };
87
88        // Extract aliases for qualified column resolution (C7)
89        let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90        let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92        let lookup_schema = lookup_plan.schema().clone();
93
94        // Build join key pairs from the equijoin conditions
95        let join_keys: Vec<JoinKeyPair> = join
96            .on
97            .iter()
98            .map(|(left_expr, right_expr)| {
99                if lookup_is_right {
100                    JoinKeyPair {
101                        stream_expr: left_expr.clone(),
102                        lookup_column: right_expr.to_string(),
103                    }
104                } else {
105                    JoinKeyPair {
106                        stream_expr: right_expr.clone(),
107                        lookup_column: left_expr.to_string(),
108                    }
109                }
110            })
111            .collect();
112
113        // Convert DataFusion join type to our lookup join type
114        let join_type = match join.join_type {
115            datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
116            datafusion::logical_expr::JoinType::Left if lookup_is_right => {
117                LookupJoinType::LeftOuter
118            }
119            datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
120                LookupJoinType::LeftOuter
121            }
122            _ => return Ok(Transformed::no(plan)),
123        };
124
125        // All lookup columns are required initially; pruning is done later
126        let required_columns: HashSet<String> = lookup_schema
127            .fields()
128            .iter()
129            .map(|f| f.name().clone())
130            .collect();
131
132        // Build output schema from stream + lookup
133        let stream_schema = stream_plan.schema();
134        let merged_fields: Vec<_> = stream_schema
135            .fields()
136            .iter()
137            .chain(lookup_schema.fields().iter())
138            .cloned()
139            .collect();
140        let output_schema = Arc::new(DFSchema::from_unqualified_fields(
141            merged_fields.into(),
142            HashMap::new(),
143        )?);
144
145        let metadata = LookupTableMetadata {
146            connector: info.properties.connector.to_string(),
147            strategy: info.properties.strategy.to_string(),
148            pushdown_mode: info.properties.pushdown_mode.to_string(),
149            primary_key: info.primary_key.clone(),
150        };
151
152        let node = LookupJoinNode::new(
153            stream_plan.clone(),
154            table_name,
155            lookup_schema,
156            join_keys,
157            join_type,
158            vec![], // predicates pushed down later
159            required_columns,
160            output_schema,
161            metadata,
162        )
163        .with_aliases(lookup_alias, stream_alias);
164
165        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
166            node: Arc::new(node),
167        })))
168    }
169}
170
/// Column pruning rule for `LookupJoinNode`.
///
/// Narrows `required_lookup_columns` to only the columns referenced
/// by downstream plan nodes.
#[derive(Debug)]
pub struct LookupColumnPruningRule;
177
178impl OptimizerRule for LookupColumnPruningRule {
179    fn name(&self) -> &'static str {
180        "lookup_column_pruning"
181    }
182
183    fn apply_order(&self) -> Option<ApplyOrder> {
184        Some(ApplyOrder::TopDown)
185    }
186
187    fn rewrite(
188        &self,
189        plan: LogicalPlan,
190        _config: &dyn OptimizerConfig,
191    ) -> Result<Transformed<LogicalPlan>> {
192        let LogicalPlan::Extension(ext) = &plan else {
193            return Ok(Transformed::no(plan));
194        };
195
196        let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
197            return Ok(Transformed::no(plan));
198        };
199
200        // Collect columns actually used downstream by walking the parent plan.
201        // For now, we use the node's schema to determine which lookup columns
202        // appear in the output. A full implementation would track column usage
203        // from parent nodes; this is a conservative starting point.
204        let schema = UserDefinedLogicalNodeCore::schema(node);
205        let used: HashSet<String> = schema
206            .fields()
207            .iter()
208            .filter(|f| node.required_lookup_columns().contains(f.name()))
209            .map(|f| f.name().clone())
210            .collect();
211
212        if used == *node.required_lookup_columns() {
213            return Ok(Transformed::no(plan));
214        }
215
216        // Rebuild with narrowed columns
217        let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
218        let pruned = LookupJoinNode::new(
219            node_inputs[0].clone(),
220            node.lookup_table_name().to_string(),
221            node.lookup_schema().clone(),
222            node.join_keys().to_vec(),
223            node.join_type(),
224            node.pushdown_predicates().to_vec(),
225            used,
226            schema.clone(),
227            node.metadata().clone(),
228        )
229        .with_local_predicates(node.local_predicates().to_vec())
230        .with_aliases(
231            node.lookup_alias().map(String::from),
232            node.stream_alias().map(String::from),
233        );
234
235        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
236            node: Arc::new(pruned),
237        })))
238    }
239}
240
241/// Extracts the table name and optional alias from a plan node.
242///
243/// Returns `(base_table_name, alias)` — alias is the `SubqueryAlias` name
244/// if the scan is wrapped in one, `None` otherwise.
245fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
246    match plan {
247        LogicalPlan::TableScan(TableScan { table_name, .. }) => {
248            Some((table_name.table().to_string(), None))
249        }
250        LogicalPlan::SubqueryAlias(alias) => {
251            let alias_name = alias.alias.table().to_string();
252            scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
253        }
254        _ => None,
255    }
256}
257
258/// Extracts the table name from a `TableScan` node, unwrapping aliases.
259fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
260    scan_table_name_and_alias(plan).map(|(name, _)| name)
261}
262
263/// Display helpers for connector/strategy/pushdown types.
264impl fmt::Display for crate::parser::lookup_table::ConnectorType {
265    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266        match self {
267            Self::Postgres => write!(f, "postgres"),
268            Self::PostgresCdc => write!(f, "postgres-cdc"),
269            Self::MysqlCdc => write!(f, "mysql-cdc"),
270            Self::Redis => write!(f, "redis"),
271            Self::S3Parquet => write!(f, "s3-parquet"),
272            Self::DeltaLake => write!(f, "delta-lake"),
273            Self::Static => write!(f, "static"),
274            Self::Custom(s) => write!(f, "{s}"),
275        }
276    }
277}
278
279impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281        match self {
282            Self::Replicated => write!(f, "replicated"),
283            Self::Partitioned => write!(f, "partitioned"),
284            Self::OnDemand => write!(f, "on-demand"),
285        }
286    }
287}
288
289impl fmt::Display for crate::parser::lookup_table::PushdownMode {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        match self {
292            Self::Auto => write!(f, "auto"),
293            Self::Enabled => write!(f, "enabled"),
294            Self::Disabled => write!(f, "disabled"),
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Builds metadata for a "customers" lookup table backed by postgres-cdc.
    fn test_lookup_info() -> LookupTableInfo {
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties: LookupTableProperties {
                connector: ConnectorType::PostgresCdc,
                connection: Some("postgresql://localhost/db".to_string()),
                strategy: LookupStrategy::Replicated,
                cache_memory: Some(ByteSize(512 * 1024 * 1024)),
                cache_disk: None,
                cache_ttl: None,
                pushdown_mode: PushdownMode::Auto,
            },
            arrow_schema,
            #[allow(clippy::disallowed_types)] // cold path: query planning
            raw_options: std::collections::HashMap::new(),
        }
    }

    /// Registers empty `orders` and `customers` batches so SQL planning works.
    fn register_test_tables(ctx: &SessionContext) {
        let orders_schema = Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]));
        let customers_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        ctx.register_batch(
            "orders",
            arrow::array::RecordBatch::new_empty(orders_schema),
        )
        .unwrap();
        ctx.register_batch(
            "customers",
            arrow::array::RecordBatch::new_empty(customers_schema),
        )
        .unwrap();
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        // Verify rewrite happened
        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        // Register both as regular tables (neither is a lookup table)
        let schema_a = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "a_id",
            DataType::Int64,
            false,
        )]));
        ctx.register_batch("a", arrow::array::RecordBatch::new_empty(schema_a))
            .unwrap();
        ctx.register_batch("b", arrow::array::RecordBatch::new_empty(schema_b))
            .unwrap();

        let plan = ctx
            .sql("SELECT * FROM a JOIN b ON a.id = b.a_id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        // No lookup tables registered
        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}