1#[allow(clippy::disallowed_types)] use std::collections::{HashMap, HashSet};
9use std::fmt;
10use std::sync::Arc;
11
12use datafusion::common::{DFSchema, Result};
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
15use datafusion_common::tree_node::Transformed;
16use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
17
18use crate::datafusion::lookup_join::{
19 JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
20};
21use crate::planner::LookupTableInfo;
22
23#[derive(Debug)]
26pub struct LookupJoinRewriteRule {
27 lookup_tables: HashMap<String, LookupTableInfo>,
29}
30
31impl LookupJoinRewriteRule {
32 #[must_use]
34 pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35 Self { lookup_tables }
36 }
37
38 fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41 if let Some(name) = scan_table_name(&join.right) {
43 if self.lookup_tables.contains_key(&name) {
44 return Some((true, name));
45 }
46 }
47 if let Some(name) = scan_table_name(&join.left) {
49 if self.lookup_tables.contains_key(&name) {
50 return Some((false, name));
51 }
52 }
53 None
54 }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58 fn name(&self) -> &'static str {
59 "lookup_join_rewrite"
60 }
61
62 fn apply_order(&self) -> Option<ApplyOrder> {
63 Some(ApplyOrder::BottomUp)
64 }
65
66 fn rewrite(
67 &self,
68 plan: LogicalPlan,
69 _config: &dyn OptimizerConfig,
70 ) -> Result<Transformed<LogicalPlan>> {
71 let LogicalPlan::Join(join) = &plan else {
72 return Ok(Transformed::no(plan));
73 };
74
75 let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76 return Ok(Transformed::no(plan));
77 };
78
79 let info = &self.lookup_tables[&table_name];
80
81 let (stream_plan, lookup_plan) = if lookup_is_right {
83 (join.left.as_ref(), join.right.as_ref())
84 } else {
85 (join.right.as_ref(), join.left.as_ref())
86 };
87
88 let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90 let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92 let lookup_schema = lookup_plan.schema().clone();
93
94 let join_keys: Vec<JoinKeyPair> = join
96 .on
97 .iter()
98 .map(|(left_expr, right_expr)| {
99 if lookup_is_right {
100 JoinKeyPair {
101 stream_expr: left_expr.clone(),
102 lookup_column: right_expr.to_string(),
103 }
104 } else {
105 JoinKeyPair {
106 stream_expr: right_expr.clone(),
107 lookup_column: left_expr.to_string(),
108 }
109 }
110 })
111 .collect();
112
113 let join_type = match join.join_type {
115 datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
116 datafusion::logical_expr::JoinType::Left if lookup_is_right => {
117 LookupJoinType::LeftOuter
118 }
119 datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
120 LookupJoinType::LeftOuter
121 }
122 _ => return Ok(Transformed::no(plan)),
123 };
124
125 let required_columns: HashSet<String> = lookup_schema
127 .fields()
128 .iter()
129 .map(|f| f.name().clone())
130 .collect();
131
132 let stream_schema = stream_plan.schema();
134 let merged_fields: Vec<_> = stream_schema
135 .fields()
136 .iter()
137 .chain(lookup_schema.fields().iter())
138 .cloned()
139 .collect();
140 let output_schema = Arc::new(DFSchema::from_unqualified_fields(
141 merged_fields.into(),
142 HashMap::new(),
143 )?);
144
145 let metadata = LookupTableMetadata {
146 connector: info.properties.connector.to_string(),
147 strategy: info.properties.strategy.to_string(),
148 pushdown_mode: info.properties.pushdown_mode.to_string(),
149 primary_key: info.primary_key.clone(),
150 };
151
152 let node = LookupJoinNode::new(
153 stream_plan.clone(),
154 table_name,
155 lookup_schema,
156 join_keys,
157 join_type,
158 vec![], required_columns,
160 output_schema,
161 metadata,
162 )
163 .with_aliases(lookup_alias, stream_alias);
164
165 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
166 node: Arc::new(node),
167 })))
168 }
169}
170
171#[derive(Debug)]
176pub struct LookupColumnPruningRule;
177
178impl OptimizerRule for LookupColumnPruningRule {
179 fn name(&self) -> &'static str {
180 "lookup_column_pruning"
181 }
182
183 fn apply_order(&self) -> Option<ApplyOrder> {
184 Some(ApplyOrder::TopDown)
185 }
186
187 fn rewrite(
188 &self,
189 plan: LogicalPlan,
190 _config: &dyn OptimizerConfig,
191 ) -> Result<Transformed<LogicalPlan>> {
192 let LogicalPlan::Extension(ext) = &plan else {
193 return Ok(Transformed::no(plan));
194 };
195
196 let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
197 return Ok(Transformed::no(plan));
198 };
199
200 let schema = UserDefinedLogicalNodeCore::schema(node);
205 let used: HashSet<String> = schema
206 .fields()
207 .iter()
208 .filter(|f| node.required_lookup_columns().contains(f.name()))
209 .map(|f| f.name().clone())
210 .collect();
211
212 if used == *node.required_lookup_columns() {
213 return Ok(Transformed::no(plan));
214 }
215
216 let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
218 let pruned = LookupJoinNode::new(
219 node_inputs[0].clone(),
220 node.lookup_table_name().to_string(),
221 node.lookup_schema().clone(),
222 node.join_keys().to_vec(),
223 node.join_type(),
224 node.pushdown_predicates().to_vec(),
225 used,
226 schema.clone(),
227 node.metadata().clone(),
228 )
229 .with_local_predicates(node.local_predicates().to_vec())
230 .with_aliases(
231 node.lookup_alias().map(String::from),
232 node.stream_alias().map(String::from),
233 );
234
235 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
236 node: Arc::new(pruned),
237 })))
238 }
239}
240
241fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
246 match plan {
247 LogicalPlan::TableScan(TableScan { table_name, .. }) => {
248 Some((table_name.table().to_string(), None))
249 }
250 LogicalPlan::SubqueryAlias(alias) => {
251 let alias_name = alias.alias.table().to_string();
252 scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
253 }
254 _ => None,
255 }
256}
257
258fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
260 scan_table_name_and_alias(plan).map(|(name, _)| name)
261}
262
263impl fmt::Display for crate::parser::lookup_table::ConnectorType {
265 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 match self {
267 Self::Postgres => write!(f, "postgres"),
268 Self::PostgresCdc => write!(f, "postgres-cdc"),
269 Self::MysqlCdc => write!(f, "mysql-cdc"),
270 Self::Redis => write!(f, "redis"),
271 Self::S3Parquet => write!(f, "s3-parquet"),
272 Self::DeltaLake => write!(f, "delta-lake"),
273 Self::Static => write!(f, "static"),
274 Self::Custom(s) => write!(f, "{s}"),
275 }
276 }
277}
278
279impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
280 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281 match self {
282 Self::Replicated => write!(f, "replicated"),
283 Self::Partitioned => write!(f, "partitioned"),
284 Self::OnDemand => write!(f, "on-demand"),
285 }
286 }
287}
288
289impl fmt::Display for crate::parser::lookup_table::PushdownMode {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 match self {
292 Self::Auto => write!(f, "auto"),
293 Self::Enabled => write!(f, "enabled"),
294 Self::Disabled => write!(f, "disabled"),
295 }
296 }
297}
298
// Unit tests for the lookup-join optimizer rules and the `Display` impls of
// the lookup-table enums.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    // Builds the `LookupTableInfo` for a `customers` lookup table (`id`, `name`)
    // using the Postgres CDC connector with the replicated strategy.
    //
    // NOTE(review): `arrow_schema` declares `id` as Int32, while
    // `register_test_tables` registers `customers.id` as Int64. The rewrite
    // rule reads only `properties` and `primary_key` from this struct, so the
    // mismatch does not affect these tests.
    fn test_lookup_info() -> LookupTableInfo {
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties: LookupTableProperties {
                connector: ConnectorType::PostgresCdc,
                connection: Some("postgresql://localhost/db".to_string()),
                strategy: LookupStrategy::Replicated,
                cache_memory: Some(ByteSize(512 * 1024 * 1024)),
                cache_disk: None,
                cache_ttl: None,
                pushdown_mode: PushdownMode::Auto,
            },
            arrow_schema,
            #[allow(clippy::disallowed_types)] raw_options: std::collections::HashMap::new(),
        }
    }

    // Registers empty in-memory `orders` (stream) and `customers` (lookup)
    // tables so that SQL referencing them can be planned.
    fn register_test_tables(ctx: &SessionContext) {
        let orders_schema = Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]));
        let customers_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        ctx.register_batch(
            "orders",
            arrow::array::RecordBatch::new_empty(orders_schema),
        )
        .unwrap();
        ctx.register_batch(
            "customers",
            arrow::array::RecordBatch::new_empty(customers_schema),
        )
        .unwrap();
    }

    // An inner join against a registered lookup table is rewritten into a
    // LookupJoin node. The rule is applied directly via `transform_down` on the
    // unoptimized plan rather than through the optimizer driver.
    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(transformed.transformed);
        // The check matches on the Debug rendering of the rewritten plan.
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    // With no registered lookup tables, an ordinary join is left untouched.
    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        let schema_a = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "a_id",
            DataType::Int64,
            false,
        )]));
        ctx.register_batch("a", arrow::array::RecordBatch::new_empty(schema_a))
            .unwrap();
        ctx.register_batch("b", arrow::array::RecordBatch::new_empty(schema_b))
            .unwrap();

        let plan = ctx
            .sql("SELECT * FROM a JOIN b ON a.id = b.a_id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(!transformed.transformed);
    }

    // A LEFT JOIN with the lookup table on the right maps to the LeftOuter
    // lookup join type.
    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    // Spot-checks the `Display` rendering of `ConnectorType`, including a custom connector.
    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    // Spot-checks the `Display` rendering of `LookupStrategy`.
    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    // Spot-checks the `Display` rendering of `PushdownMode`.
    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}