Skip to main content

laminar_sql/planner/
lookup_join.rs

1//! Optimizer rules that rewrite standard JOINs to `LookupJoinNode`.
2//!
3//! When a query joins a streaming source with a registered lookup table,
4//! the `LookupJoinRewriteRule` replaces the standard hash/merge join
5//! with a `LookupJoinNode` that uses the lookup source connector.
6
7use std::collections::{HashMap, HashSet};
8use std::fmt;
9use std::sync::Arc;
10
11use datafusion::common::{DFSchema, Result};
12use datafusion::logical_expr::logical_plan::LogicalPlan;
13use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
14use datafusion_common::tree_node::Transformed;
15use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
16
17use crate::datafusion::lookup_join::{
18    JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
19};
20use crate::planner::LookupTableInfo;
21
22/// Rewrites standard JOIN nodes that reference a lookup table into
23/// `LookupJoinNode` extension nodes.
24#[derive(Debug)]
25pub struct LookupJoinRewriteRule {
26    /// Registered lookup tables, keyed by name.
27    lookup_tables: HashMap<String, LookupTableInfo>,
28}
29
30impl LookupJoinRewriteRule {
31    /// Creates a new rewrite rule with the given set of registered lookup tables.
32    #[must_use]
33    pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
34        Self { lookup_tables }
35    }
36
37    /// Detects which side of a join (if any) is a lookup table scan.
38    /// Returns `Some((lookup_side_is_right, table_name))`.
39    fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
40        // Check right side
41        if let Some(name) = scan_table_name(&join.right) {
42            if self.lookup_tables.contains_key(&name) {
43                return Some((true, name));
44            }
45        }
46        // Check left side
47        if let Some(name) = scan_table_name(&join.left) {
48            if self.lookup_tables.contains_key(&name) {
49                return Some((false, name));
50            }
51        }
52        None
53    }
54}
55
56impl OptimizerRule for LookupJoinRewriteRule {
57    fn name(&self) -> &'static str {
58        "lookup_join_rewrite"
59    }
60
61    fn apply_order(&self) -> Option<ApplyOrder> {
62        Some(ApplyOrder::BottomUp)
63    }
64
65    fn rewrite(
66        &self,
67        plan: LogicalPlan,
68        _config: &dyn OptimizerConfig,
69    ) -> Result<Transformed<LogicalPlan>> {
70        let LogicalPlan::Join(join) = &plan else {
71            return Ok(Transformed::no(plan));
72        };
73
74        let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
75            return Ok(Transformed::no(plan));
76        };
77
78        let info = &self.lookup_tables[&table_name];
79
80        // Determine which side is stream and which is lookup
81        let (stream_plan, lookup_plan) = if lookup_is_right {
82            (join.left.as_ref(), join.right.as_ref())
83        } else {
84            (join.right.as_ref(), join.left.as_ref())
85        };
86
87        // Extract aliases for qualified column resolution (C7)
88        let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
89        let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
90
91        let lookup_schema = lookup_plan.schema().clone();
92
93        // Build join key pairs from the equijoin conditions
94        let join_keys: Vec<JoinKeyPair> = join
95            .on
96            .iter()
97            .map(|(left_expr, right_expr)| {
98                if lookup_is_right {
99                    JoinKeyPair {
100                        stream_expr: left_expr.clone(),
101                        lookup_column: right_expr.to_string(),
102                    }
103                } else {
104                    JoinKeyPair {
105                        stream_expr: right_expr.clone(),
106                        lookup_column: left_expr.to_string(),
107                    }
108                }
109            })
110            .collect();
111
112        // Convert DataFusion join type to our lookup join type
113        let join_type = match join.join_type {
114            datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
115            datafusion::logical_expr::JoinType::Left if lookup_is_right => {
116                LookupJoinType::LeftOuter
117            }
118            datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
119                LookupJoinType::LeftOuter
120            }
121            _ => return Ok(Transformed::no(plan)),
122        };
123
124        // All lookup columns are required initially; pruning is done later
125        let required_columns: HashSet<String> = lookup_schema
126            .fields()
127            .iter()
128            .map(|f| f.name().clone())
129            .collect();
130
131        // Build output schema from stream + lookup
132        let stream_schema = stream_plan.schema();
133        let merged_fields: Vec<_> = stream_schema
134            .fields()
135            .iter()
136            .chain(lookup_schema.fields().iter())
137            .cloned()
138            .collect();
139        let output_schema = Arc::new(DFSchema::from_unqualified_fields(
140            merged_fields.into(),
141            HashMap::new(),
142        )?);
143
144        let metadata = LookupTableMetadata {
145            connector: info.properties.connector.to_string(),
146            strategy: info.properties.strategy.to_string(),
147            pushdown_mode: info.properties.pushdown_mode.to_string(),
148            primary_key: info.primary_key.clone(),
149        };
150
151        let node = LookupJoinNode::new(
152            stream_plan.clone(),
153            table_name,
154            lookup_schema,
155            join_keys,
156            join_type,
157            vec![], // predicates pushed down later
158            required_columns,
159            output_schema,
160            metadata,
161        )
162        .with_aliases(lookup_alias, stream_alias);
163
164        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
165            node: Arc::new(node),
166        })))
167    }
168}
169
170/// Column pruning rule for `LookupJoinNode`.
171///
172/// Narrows `required_lookup_columns` to only the columns referenced
173/// by downstream plan nodes.
174#[derive(Debug)]
175pub struct LookupColumnPruningRule;
176
177impl OptimizerRule for LookupColumnPruningRule {
178    fn name(&self) -> &'static str {
179        "lookup_column_pruning"
180    }
181
182    fn apply_order(&self) -> Option<ApplyOrder> {
183        Some(ApplyOrder::TopDown)
184    }
185
186    fn rewrite(
187        &self,
188        plan: LogicalPlan,
189        _config: &dyn OptimizerConfig,
190    ) -> Result<Transformed<LogicalPlan>> {
191        let LogicalPlan::Extension(ext) = &plan else {
192            return Ok(Transformed::no(plan));
193        };
194
195        let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
196            return Ok(Transformed::no(plan));
197        };
198
199        // Collect columns actually used downstream by walking the parent plan.
200        // For now, we use the node's schema to determine which lookup columns
201        // appear in the output. A full implementation would track column usage
202        // from parent nodes; this is a conservative starting point.
203        let schema = UserDefinedLogicalNodeCore::schema(node);
204        let used: HashSet<String> = schema
205            .fields()
206            .iter()
207            .filter(|f| node.required_lookup_columns().contains(f.name()))
208            .map(|f| f.name().clone())
209            .collect();
210
211        if used == *node.required_lookup_columns() {
212            return Ok(Transformed::no(plan));
213        }
214
215        // Rebuild with narrowed columns
216        let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
217        let pruned = LookupJoinNode::new(
218            node_inputs[0].clone(),
219            node.lookup_table_name().to_string(),
220            node.lookup_schema().clone(),
221            node.join_keys().to_vec(),
222            node.join_type(),
223            node.pushdown_predicates().to_vec(),
224            used,
225            schema.clone(),
226            node.metadata().clone(),
227        )
228        .with_local_predicates(node.local_predicates().to_vec())
229        .with_aliases(
230            node.lookup_alias().map(String::from),
231            node.stream_alias().map(String::from),
232        );
233
234        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
235            node: Arc::new(pruned),
236        })))
237    }
238}
239
240/// Extracts the table name and optional alias from a plan node.
241///
242/// Returns `(base_table_name, alias)` — alias is the `SubqueryAlias` name
243/// if the scan is wrapped in one, `None` otherwise.
244fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
245    match plan {
246        LogicalPlan::TableScan(TableScan { table_name, .. }) => {
247            Some((table_name.table().to_string(), None))
248        }
249        LogicalPlan::SubqueryAlias(alias) => {
250            let alias_name = alias.alias.table().to_string();
251            scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
252        }
253        _ => None,
254    }
255}
256
257/// Extracts the table name from a `TableScan` node, unwrapping aliases.
258fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
259    scan_table_name_and_alias(plan).map(|(name, _)| name)
260}
261
262/// Display helpers for connector/strategy/pushdown types.
263impl fmt::Display for crate::parser::lookup_table::ConnectorType {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        match self {
266            Self::PostgresCdc => write!(f, "postgres-cdc"),
267            Self::MysqlCdc => write!(f, "mysql-cdc"),
268            Self::Redis => write!(f, "redis"),
269            Self::S3Parquet => write!(f, "s3-parquet"),
270            Self::Static => write!(f, "static"),
271            Self::Custom(s) => write!(f, "{s}"),
272        }
273    }
274}
275
276impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        match self {
279            Self::Replicated => write!(f, "replicated"),
280            Self::Partitioned => write!(f, "partitioned"),
281            Self::OnDemand => write!(f, "on-demand"),
282        }
283    }
284}
285
286impl fmt::Display for crate::parser::lookup_table::PushdownMode {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        match self {
289            Self::Auto => write!(f, "auto"),
290            Self::Enabled => write!(f, "enabled"),
291            Self::Disabled => write!(f, "disabled"),
292        }
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Lookup-table registration used by the tests: a replicated
    /// `customers` table keyed by `id`.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties,
        }
    }

    /// Registers an empty in-memory table named `name` with the given fields.
    fn register_empty_table(ctx: &SessionContext, name: &str, fields: Vec<Field>) {
        let schema = Arc::new(Schema::new(fields));
        ctx.register_batch(name, arrow::array::RecordBatch::new_empty(schema))
            .unwrap();
    }

    /// Registers the empty `orders` and `customers` tables.
    fn register_test_tables(ctx: &SessionContext) {
        register_empty_table(
            ctx,
            "orders",
            vec![
                Field::new("order_id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("amount", DataType::Float64, false),
            ],
        );
        register_empty_table(
            ctx,
            "customers",
            vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
            ],
        );
    }

    /// Rewrite rule that knows `customers` as its only lookup table.
    fn customers_rule() -> LookupJoinRewriteRule {
        let registered = HashMap::from([("customers".to_string(), test_lookup_info())]);
        LookupJoinRewriteRule::new(registered)
    }

    /// Plans `sql` without optimization and runs `rule` over the plan top-down.
    async fn rewrite_sql(
        ctx: &SessionContext,
        sql: &str,
        rule: &LookupJoinRewriteRule,
    ) -> Transformed<LogicalPlan> {
        let unoptimized = ctx.sql(sql).await.unwrap().into_unoptimized_plan();
        unoptimized
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let transformed = rewrite_sql(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
            &customers_rule(),
        )
        .await;

        // Verify rewrite happened
        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        // Register both as regular tables (neither is a lookup table)
        register_empty_table(&ctx, "a", vec![Field::new("id", DataType::Int64, false)]);
        register_empty_table(&ctx, "b", vec![Field::new("a_id", DataType::Int64, false)]);

        // No lookup tables registered
        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = rewrite_sql(&ctx, "SELECT * FROM a JOIN b ON a.id = b.a_id", &rule).await;

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let transformed = rewrite_sql(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
            &customers_rule(),
        )
        .await;

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}