Skip to main content

laminar_sql/planner/
lookup_join.rs

1//! Optimizer rules that rewrite standard JOINs to `LookupJoinNode`.
2//!
3//! When a query joins a streaming source with a registered lookup table,
4//! the `LookupJoinRewriteRule` replaces the standard hash/merge join
5//! with a `LookupJoinNode` that uses the lookup source connector.
6
7use std::collections::{HashMap, HashSet};
8use std::fmt;
9use std::sync::Arc;
10
11use datafusion::common::{DFSchema, Result};
12use datafusion::logical_expr::logical_plan::LogicalPlan;
13use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
14use datafusion_common::tree_node::Transformed;
15use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
16
17use crate::datafusion::lookup_join::{
18    JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
19};
20use crate::planner::LookupTableInfo;
21
22/// Rewrites standard JOIN nodes that reference a lookup table into
23/// `LookupJoinNode` extension nodes.
24#[derive(Debug)]
25pub struct LookupJoinRewriteRule {
26    /// Registered lookup tables, keyed by name.
27    lookup_tables: HashMap<String, LookupTableInfo>,
28}
29
30impl LookupJoinRewriteRule {
31    /// Creates a new rewrite rule with the given set of registered lookup tables.
32    #[must_use]
33    pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
34        Self { lookup_tables }
35    }
36
37    /// Detects which side of a join (if any) is a lookup table scan.
38    /// Returns `Some((lookup_side_is_right, table_name))`.
39    fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
40        // Check right side
41        if let Some(name) = scan_table_name(&join.right) {
42            if self.lookup_tables.contains_key(&name) {
43                return Some((true, name));
44            }
45        }
46        // Check left side
47        if let Some(name) = scan_table_name(&join.left) {
48            if self.lookup_tables.contains_key(&name) {
49                return Some((false, name));
50            }
51        }
52        None
53    }
54}
55
56impl OptimizerRule for LookupJoinRewriteRule {
57    fn name(&self) -> &'static str {
58        "lookup_join_rewrite"
59    }
60
61    fn apply_order(&self) -> Option<ApplyOrder> {
62        Some(ApplyOrder::BottomUp)
63    }
64
65    fn rewrite(
66        &self,
67        plan: LogicalPlan,
68        _config: &dyn OptimizerConfig,
69    ) -> Result<Transformed<LogicalPlan>> {
70        let LogicalPlan::Join(join) = &plan else {
71            return Ok(Transformed::no(plan));
72        };
73
74        let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
75            return Ok(Transformed::no(plan));
76        };
77
78        let info = &self.lookup_tables[&table_name];
79
80        // Determine which side is stream and which is lookup
81        let (stream_plan, lookup_plan) = if lookup_is_right {
82            (join.left.as_ref(), join.right.as_ref())
83        } else {
84            (join.right.as_ref(), join.left.as_ref())
85        };
86
87        // Extract aliases for qualified column resolution (C7)
88        let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
89        let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
90
91        let lookup_schema = lookup_plan.schema().clone();
92
93        // Build join key pairs from the equijoin conditions
94        let join_keys: Vec<JoinKeyPair> = join
95            .on
96            .iter()
97            .map(|(left_expr, right_expr)| {
98                if lookup_is_right {
99                    JoinKeyPair {
100                        stream_expr: left_expr.clone(),
101                        lookup_column: right_expr.to_string(),
102                    }
103                } else {
104                    JoinKeyPair {
105                        stream_expr: right_expr.clone(),
106                        lookup_column: left_expr.to_string(),
107                    }
108                }
109            })
110            .collect();
111
112        // Convert DataFusion join type to our lookup join type
113        let join_type = match join.join_type {
114            datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
115            datafusion::logical_expr::JoinType::Left if lookup_is_right => {
116                LookupJoinType::LeftOuter
117            }
118            datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
119                LookupJoinType::LeftOuter
120            }
121            _ => return Ok(Transformed::no(plan)),
122        };
123
124        // All lookup columns are required initially; pruning is done later
125        let required_columns: HashSet<String> = lookup_schema
126            .fields()
127            .iter()
128            .map(|f| f.name().clone())
129            .collect();
130
131        // Build output schema from stream + lookup
132        let stream_schema = stream_plan.schema();
133        let merged_fields: Vec<_> = stream_schema
134            .fields()
135            .iter()
136            .chain(lookup_schema.fields().iter())
137            .cloned()
138            .collect();
139        let output_schema = Arc::new(DFSchema::from_unqualified_fields(
140            merged_fields.into(),
141            HashMap::new(),
142        )?);
143
144        let metadata = LookupTableMetadata {
145            connector: info.properties.connector.to_string(),
146            strategy: info.properties.strategy.to_string(),
147            pushdown_mode: info.properties.pushdown_mode.to_string(),
148            primary_key: info.primary_key.clone(),
149        };
150
151        let node = LookupJoinNode::new(
152            stream_plan.clone(),
153            table_name,
154            lookup_schema,
155            join_keys,
156            join_type,
157            vec![], // predicates pushed down later
158            required_columns,
159            output_schema,
160            metadata,
161        )
162        .with_aliases(lookup_alias, stream_alias);
163
164        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
165            node: Arc::new(node),
166        })))
167    }
168}
169
170/// Column pruning rule for `LookupJoinNode`.
171///
172/// Narrows `required_lookup_columns` to only the columns referenced
173/// by downstream plan nodes.
174#[derive(Debug)]
175pub struct LookupColumnPruningRule;
176
177impl OptimizerRule for LookupColumnPruningRule {
178    fn name(&self) -> &'static str {
179        "lookup_column_pruning"
180    }
181
182    fn apply_order(&self) -> Option<ApplyOrder> {
183        Some(ApplyOrder::TopDown)
184    }
185
186    fn rewrite(
187        &self,
188        plan: LogicalPlan,
189        _config: &dyn OptimizerConfig,
190    ) -> Result<Transformed<LogicalPlan>> {
191        let LogicalPlan::Extension(ext) = &plan else {
192            return Ok(Transformed::no(plan));
193        };
194
195        let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
196            return Ok(Transformed::no(plan));
197        };
198
199        // Collect columns actually used downstream by walking the parent plan.
200        // For now, we use the node's schema to determine which lookup columns
201        // appear in the output. A full implementation would track column usage
202        // from parent nodes; this is a conservative starting point.
203        let schema = UserDefinedLogicalNodeCore::schema(node);
204        let used: HashSet<String> = schema
205            .fields()
206            .iter()
207            .filter(|f| node.required_lookup_columns().contains(f.name()))
208            .map(|f| f.name().clone())
209            .collect();
210
211        if used == *node.required_lookup_columns() {
212            return Ok(Transformed::no(plan));
213        }
214
215        // Rebuild with narrowed columns
216        let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
217        let pruned = LookupJoinNode::new(
218            node_inputs[0].clone(),
219            node.lookup_table_name().to_string(),
220            node.lookup_schema().clone(),
221            node.join_keys().to_vec(),
222            node.join_type(),
223            node.pushdown_predicates().to_vec(),
224            used,
225            schema.clone(),
226            node.metadata().clone(),
227        )
228        .with_local_predicates(node.local_predicates().to_vec())
229        .with_aliases(
230            node.lookup_alias().map(String::from),
231            node.stream_alias().map(String::from),
232        );
233
234        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
235            node: Arc::new(pruned),
236        })))
237    }
238}
239
240/// Extracts the table name and optional alias from a plan node.
241///
242/// Returns `(base_table_name, alias)` — alias is the `SubqueryAlias` name
243/// if the scan is wrapped in one, `None` otherwise.
244fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
245    match plan {
246        LogicalPlan::TableScan(TableScan { table_name, .. }) => {
247            Some((table_name.table().to_string(), None))
248        }
249        LogicalPlan::SubqueryAlias(alias) => {
250            let alias_name = alias.alias.table().to_string();
251            scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
252        }
253        _ => None,
254    }
255}
256
257/// Extracts the table name from a `TableScan` node, unwrapping aliases.
258fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
259    scan_table_name_and_alias(plan).map(|(name, _)| name)
260}
261
262/// Display helpers for connector/strategy/pushdown types.
263impl fmt::Display for crate::parser::lookup_table::ConnectorType {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        match self {
266            Self::PostgresCdc => write!(f, "postgres-cdc"),
267            Self::MysqlCdc => write!(f, "mysql-cdc"),
268            Self::Redis => write!(f, "redis"),
269            Self::S3Parquet => write!(f, "s3-parquet"),
270            Self::Static => write!(f, "static"),
271            Self::Custom(s) => write!(f, "{s}"),
272        }
273    }
274}
275
276impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        match self {
279            Self::Replicated => write!(f, "replicated"),
280            Self::Partitioned => write!(f, "partitioned"),
281            Self::OnDemand => write!(f, "on-demand"),
282        }
283    }
284}
285
286impl fmt::Display for crate::parser::lookup_table::PushdownMode {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        match self {
289            Self::Auto => write!(f, "auto"),
290            Self::Enabled => write!(f, "enabled"),
291            Self::Disabled => write!(f, "disabled"),
292        }
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Lookup-table description for the `customers` fixture table.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };

        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties,
        }
    }

    /// Registers an empty in-memory table called `name` with the given fields.
    fn register_empty_table(ctx: &SessionContext, name: &str, fields: Vec<Field>) {
        let schema = Arc::new(Schema::new(fields));
        ctx.register_batch(name, arrow::array::RecordBatch::new_empty(schema))
            .unwrap();
    }

    /// Registers the empty `orders` and `customers` fixture tables.
    fn register_test_tables(ctx: &SessionContext) {
        register_empty_table(
            ctx,
            "orders",
            vec![
                Field::new("order_id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("amount", DataType::Float64, false),
            ],
        );
        register_empty_table(
            ctx,
            "customers",
            vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
            ],
        );
    }

    /// Rewrite rule that treats `customers` as a registered lookup table.
    fn customers_rule() -> LookupJoinRewriteRule {
        let lookup_tables = HashMap::from([("customers".to_string(), test_lookup_info())]);
        LookupJoinRewriteRule::new(lookup_tables)
    }

    /// Plans `sql` against `ctx` and returns the unoptimized logical plan.
    async fn unoptimized_plan(ctx: &SessionContext, sql: &str) -> LogicalPlan {
        ctx.sql(sql).await.unwrap().into_unoptimized_plan()
    }

    /// Runs `rule` over every node of `plan`, top-down.
    fn apply_rule(rule: &LookupJoinRewriteRule, plan: LogicalPlan) -> Transformed<LogicalPlan> {
        plan.transform_down(|node| rule.rewrite(node, &OptimizerContext::new()))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = SessionContext::new();
        register_test_tables(&ctx);

        let plan = unoptimized_plan(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
        )
        .await;

        let transformed = apply_rule(&customers_rule(), plan);

        // The join must have been replaced by a lookup join node.
        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = SessionContext::new();
        // Both tables are ordinary tables; neither is a lookup table.
        register_empty_table(&ctx, "a", vec![Field::new("id", DataType::Int64, false)]);
        register_empty_table(&ctx, "b", vec![Field::new("a_id", DataType::Int64, false)]);

        let plan = unoptimized_plan(&ctx, "SELECT * FROM a JOIN b ON a.id = b.a_id").await;

        // The rule knows about no lookup tables at all.
        let rule = LookupJoinRewriteRule::new(HashMap::new());
        let transformed = apply_rule(&rule, plan);

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = SessionContext::new();
        register_test_tables(&ctx);

        let plan = unoptimized_plan(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
        )
        .await;

        let transformed = apply_rule(&customers_rule(), plan);

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}