1#[allow(clippy::disallowed_types)] use std::collections::{HashMap, HashSet};
9use std::fmt;
10use std::sync::Arc;
11
12use datafusion::common::{DFSchema, Result};
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
15use datafusion_common::tree_node::Transformed;
16use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
17
18use crate::datafusion::lookup_join::{
19 JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
20};
21use crate::planner::LookupTableInfo;
22
23#[derive(Debug)]
26pub struct LookupJoinRewriteRule {
27 lookup_tables: HashMap<String, LookupTableInfo>,
29}
30
31impl LookupJoinRewriteRule {
32 #[must_use]
34 pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35 Self { lookup_tables }
36 }
37
38 fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41 if let Some(name) = scan_table_name(&join.right) {
43 if self.lookup_tables.contains_key(&name) {
44 return Some((true, name));
45 }
46 }
47 if let Some(name) = scan_table_name(&join.left) {
49 if self.lookup_tables.contains_key(&name) {
50 return Some((false, name));
51 }
52 }
53 None
54 }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58 fn name(&self) -> &'static str {
59 "lookup_join_rewrite"
60 }
61
62 fn apply_order(&self) -> Option<ApplyOrder> {
63 Some(ApplyOrder::BottomUp)
64 }
65
66 fn rewrite(
67 &self,
68 plan: LogicalPlan,
69 _config: &dyn OptimizerConfig,
70 ) -> Result<Transformed<LogicalPlan>> {
71 let LogicalPlan::Join(join) = &plan else {
72 return Ok(Transformed::no(plan));
73 };
74
75 let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76 return Ok(Transformed::no(plan));
77 };
78
79 let info = &self.lookup_tables[&table_name];
80
81 let (stream_plan, lookup_plan) = if lookup_is_right {
83 (join.left.as_ref(), join.right.as_ref())
84 } else {
85 (join.right.as_ref(), join.left.as_ref())
86 };
87
88 let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90 let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92 let lookup_schema = lookup_plan.schema().clone();
93
94 let join_keys: Vec<JoinKeyPair> = join
96 .on
97 .iter()
98 .map(|(left_expr, right_expr)| {
99 if lookup_is_right {
100 JoinKeyPair {
101 stream_expr: left_expr.clone(),
102 lookup_column: right_expr.to_string(),
103 }
104 } else {
105 JoinKeyPair {
106 stream_expr: right_expr.clone(),
107 lookup_column: left_expr.to_string(),
108 }
109 }
110 })
111 .collect();
112
113 let join_type = match join.join_type {
115 datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
116 datafusion::logical_expr::JoinType::Left if lookup_is_right => {
117 LookupJoinType::LeftOuter
118 }
119 datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
120 LookupJoinType::LeftOuter
121 }
122 _ => return Ok(Transformed::no(plan)),
123 };
124
125 let required_columns: HashSet<String> = lookup_schema
127 .fields()
128 .iter()
129 .map(|f| f.name().clone())
130 .collect();
131
132 let stream_schema = stream_plan.schema();
134 let merged_fields: Vec<_> = stream_schema
135 .fields()
136 .iter()
137 .chain(lookup_schema.fields().iter())
138 .cloned()
139 .collect();
140 let output_schema = Arc::new(DFSchema::from_unqualified_fields(
141 merged_fields.into(),
142 HashMap::new(),
143 )?);
144
145 let metadata = LookupTableMetadata {
146 connector: info.properties.connector.to_string(),
147 strategy: info.properties.strategy.to_string(),
148 pushdown_mode: info.properties.pushdown_mode.to_string(),
149 primary_key: info.primary_key.clone(),
150 };
151
152 let node = LookupJoinNode::new(
153 stream_plan.clone(),
154 table_name,
155 lookup_schema,
156 join_keys,
157 join_type,
158 vec![], required_columns,
160 output_schema,
161 metadata,
162 )
163 .with_aliases(lookup_alias, stream_alias);
164
165 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
166 node: Arc::new(node),
167 })))
168 }
169}
170
171#[derive(Debug)]
176pub struct LookupColumnPruningRule;
177
178impl OptimizerRule for LookupColumnPruningRule {
179 fn name(&self) -> &'static str {
180 "lookup_column_pruning"
181 }
182
183 fn apply_order(&self) -> Option<ApplyOrder> {
184 Some(ApplyOrder::TopDown)
185 }
186
187 fn rewrite(
188 &self,
189 plan: LogicalPlan,
190 _config: &dyn OptimizerConfig,
191 ) -> Result<Transformed<LogicalPlan>> {
192 let LogicalPlan::Extension(ext) = &plan else {
193 return Ok(Transformed::no(plan));
194 };
195
196 let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
197 return Ok(Transformed::no(plan));
198 };
199
200 let schema = UserDefinedLogicalNodeCore::schema(node);
205 let used: HashSet<String> = schema
206 .fields()
207 .iter()
208 .filter(|f| node.required_lookup_columns().contains(f.name()))
209 .map(|f| f.name().clone())
210 .collect();
211
212 if used == *node.required_lookup_columns() {
213 return Ok(Transformed::no(plan));
214 }
215
216 let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
218 let pruned = LookupJoinNode::new(
219 node_inputs[0].clone(),
220 node.lookup_table_name().to_string(),
221 node.lookup_schema().clone(),
222 node.join_keys().to_vec(),
223 node.join_type(),
224 node.pushdown_predicates().to_vec(),
225 used,
226 schema.clone(),
227 node.metadata().clone(),
228 )
229 .with_local_predicates(node.local_predicates().to_vec())
230 .with_aliases(
231 node.lookup_alias().map(String::from),
232 node.stream_alias().map(String::from),
233 );
234
235 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
236 node: Arc::new(pruned),
237 })))
238 }
239}
240
241fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
246 match plan {
247 LogicalPlan::TableScan(TableScan { table_name, .. }) => {
248 Some((table_name.table().to_string(), None))
249 }
250 LogicalPlan::SubqueryAlias(alias) => {
251 let alias_name = alias.alias.table().to_string();
252 scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
253 }
254 _ => None,
255 }
256}
257
258fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
260 scan_table_name_and_alias(plan).map(|(name, _)| name)
261}
262
263impl fmt::Display for crate::parser::lookup_table::ConnectorType {
265 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 match self {
267 Self::PostgresCdc => write!(f, "postgres-cdc"),
268 Self::MysqlCdc => write!(f, "mysql-cdc"),
269 Self::Redis => write!(f, "redis"),
270 Self::S3Parquet => write!(f, "s3-parquet"),
271 Self::DeltaLake => write!(f, "delta-lake"),
272 Self::Static => write!(f, "static"),
273 Self::Custom(s) => write!(f, "{s}"),
274 }
275 }
276}
277
278impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 match self {
281 Self::Replicated => write!(f, "replicated"),
282 Self::Partitioned => write!(f, "partitioned"),
283 Self::OnDemand => write!(f, "on-demand"),
284 }
285 }
286}
287
288impl fmt::Display for crate::parser::lookup_table::PushdownMode {
289 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290 match self {
291 Self::Auto => write!(f, "auto"),
292 Self::Enabled => write!(f, "enabled"),
293 Self::Disabled => write!(f, "disabled"),
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Lookup-table description for a `customers` table keyed by `id`.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties,
            arrow_schema,
            #[allow(clippy::disallowed_types)]
            raw_options: std::collections::HashMap::new(),
        }
    }

    /// Builds an empty record batch with the given columns.
    fn empty_batch(fields: Vec<Field>) -> arrow::array::RecordBatch {
        arrow::array::RecordBatch::new_empty(Arc::new(Schema::new(fields)))
    }

    /// Registers empty `orders` and `customers` tables on `ctx`.
    fn register_test_tables(ctx: &SessionContext) {
        let orders = empty_batch(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]);
        let customers = empty_batch(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);
        ctx.register_batch("orders", orders).unwrap();
        ctx.register_batch("customers", customers).unwrap();
    }

    /// Rule that knows `customers` as its only lookup table.
    fn customers_rule() -> LookupJoinRewriteRule {
        LookupJoinRewriteRule::new(HashMap::from([(
            "customers".to_string(),
            test_lookup_info(),
        )]))
    }

    /// Plans `sql` without optimization and applies `rule` to every node,
    /// walking the plan top-down.
    async fn apply_rule(
        ctx: &SessionContext,
        sql: &str,
        rule: &LookupJoinRewriteRule,
    ) -> Transformed<LogicalPlan> {
        let plan = ctx.sql(sql).await.unwrap().into_unoptimized_plan();
        let config = OptimizerContext::new();
        plan.transform_down(|node| rule.rewrite(node, &config))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let transformed = apply_rule(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
            &customers_rule(),
        )
        .await;

        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        ctx.register_batch(
            "a",
            empty_batch(vec![Field::new("id", DataType::Int64, false)]),
        )
        .unwrap();
        ctx.register_batch(
            "b",
            empty_batch(vec![Field::new("a_id", DataType::Int64, false)]),
        )
        .unwrap();

        // With no lookup tables registered, nothing may be rewritten.
        let rule = LookupJoinRewriteRule::new(HashMap::new());
        let transformed = apply_rule(&ctx, "SELECT * FROM a JOIN b ON a.id = b.a_id", &rule).await;

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let transformed = apply_rule(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
            &customers_rule(),
        )
        .await;

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}