1#[allow(clippy::disallowed_types)] use std::collections::{HashMap, HashSet};
9use std::fmt;
10use std::sync::Arc;
11
12use datafusion::common::{DFSchema, Result};
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
15use datafusion_common::tree_node::Transformed;
16use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
17
18use crate::datafusion::lookup_join::{
19 JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
20};
21use crate::planner::LookupTableInfo;
22
/// Logical optimizer rule that replaces a `Join` against a registered lookup
/// table with a [`LookupJoinNode`] extension node.
///
/// A join side is recognised as a lookup table when it is a `TableScan`
/// (optionally wrapped in `SubqueryAlias` nodes) whose base table name is a
/// key of `lookup_tables`.
#[derive(Debug)]
pub struct LookupJoinRewriteRule {
    /// Registered lookup tables, keyed by base table name.
    lookup_tables: HashMap<String, LookupTableInfo>,
}
30
31impl LookupJoinRewriteRule {
32 #[must_use]
34 pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35 Self { lookup_tables }
36 }
37
38 fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41 if let Some(name) = scan_table_name(&join.right) {
43 if self.lookup_tables.contains_key(&name) {
44 return Some((true, name));
45 }
46 }
47 if let Some(name) = scan_table_name(&join.left) {
49 if self.lookup_tables.contains_key(&name) {
50 return Some((false, name));
51 }
52 }
53 None
54 }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58 fn name(&self) -> &'static str {
59 "lookup_join_rewrite"
60 }
61
62 fn apply_order(&self) -> Option<ApplyOrder> {
63 Some(ApplyOrder::BottomUp)
64 }
65
66 fn rewrite(
67 &self,
68 plan: LogicalPlan,
69 _config: &dyn OptimizerConfig,
70 ) -> Result<Transformed<LogicalPlan>> {
71 let LogicalPlan::Join(join) = &plan else {
72 return Ok(Transformed::no(plan));
73 };
74
75 let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76 return Ok(Transformed::no(plan));
77 };
78
79 let info = &self.lookup_tables[&table_name];
80
81 let (stream_plan, lookup_plan) = if lookup_is_right {
83 (join.left.as_ref(), join.right.as_ref())
84 } else {
85 (join.right.as_ref(), join.left.as_ref())
86 };
87
88 let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90 let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92 let lookup_schema = lookup_plan.schema().clone();
93
94 let join_keys: Vec<JoinKeyPair> = join
96 .on
97 .iter()
98 .map(|(left_expr, right_expr)| {
99 if lookup_is_right {
100 JoinKeyPair {
101 stream_expr: left_expr.clone(),
102 lookup_column: right_expr.to_string(),
103 }
104 } else {
105 JoinKeyPair {
106 stream_expr: right_expr.clone(),
107 lookup_column: left_expr.to_string(),
108 }
109 }
110 })
111 .collect();
112
113 let join_type = match join.join_type {
115 datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
116 datafusion::logical_expr::JoinType::Left if lookup_is_right => {
117 LookupJoinType::LeftOuter
118 }
119 datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
120 LookupJoinType::LeftOuter
121 }
122 _ => return Ok(Transformed::no(plan)),
123 };
124
125 let required_columns: HashSet<String> = lookup_schema
127 .fields()
128 .iter()
129 .map(|f| f.name().clone())
130 .collect();
131
132 let stream_schema = stream_plan.schema();
134 let merged_fields: Vec<_> = stream_schema
135 .fields()
136 .iter()
137 .chain(lookup_schema.fields().iter())
138 .cloned()
139 .collect();
140 let output_schema = Arc::new(DFSchema::from_unqualified_fields(
141 merged_fields.into(),
142 HashMap::new(),
143 )?);
144
145 let metadata = LookupTableMetadata {
146 connector: info.properties.connector.to_string(),
147 strategy: info.properties.strategy.to_string(),
148 pushdown_mode: info.properties.pushdown_mode.to_string(),
149 primary_key: info.primary_key.clone(),
150 };
151
152 let node = LookupJoinNode::new(
153 stream_plan.clone(),
154 table_name,
155 lookup_schema,
156 join_keys,
157 join_type,
158 vec![], required_columns,
160 output_schema,
161 metadata,
162 )
163 .with_aliases(lookup_alias, stream_alias);
164
165 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
166 node: Arc::new(node),
167 })))
168 }
169}
170
/// Logical optimizer rule that rebuilds a [`LookupJoinNode`] so that its
/// `required_lookup_columns` contains only names that also appear in the
/// node's output schema.
///
/// Applied top-down; every other plan node is returned unchanged.
#[derive(Debug)]
pub struct LookupColumnPruningRule;
177
178impl OptimizerRule for LookupColumnPruningRule {
179 fn name(&self) -> &'static str {
180 "lookup_column_pruning"
181 }
182
183 fn apply_order(&self) -> Option<ApplyOrder> {
184 Some(ApplyOrder::TopDown)
185 }
186
187 fn rewrite(
188 &self,
189 plan: LogicalPlan,
190 _config: &dyn OptimizerConfig,
191 ) -> Result<Transformed<LogicalPlan>> {
192 let LogicalPlan::Extension(ext) = &plan else {
193 return Ok(Transformed::no(plan));
194 };
195
196 let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
197 return Ok(Transformed::no(plan));
198 };
199
200 let schema = UserDefinedLogicalNodeCore::schema(node);
205 let used: HashSet<String> = schema
206 .fields()
207 .iter()
208 .filter(|f| node.required_lookup_columns().contains(f.name()))
209 .map(|f| f.name().clone())
210 .collect();
211
212 if used == *node.required_lookup_columns() {
213 return Ok(Transformed::no(plan));
214 }
215
216 let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
218 let pruned = LookupJoinNode::new(
219 node_inputs[0].clone(),
220 node.lookup_table_name().to_string(),
221 node.lookup_schema().clone(),
222 node.join_keys().to_vec(),
223 node.join_type(),
224 node.pushdown_predicates().to_vec(),
225 used,
226 schema.clone(),
227 node.metadata().clone(),
228 )
229 .with_local_predicates(node.local_predicates().to_vec())
230 .with_aliases(
231 node.lookup_alias().map(String::from),
232 node.stream_alias().map(String::from),
233 );
234
235 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
236 node: Arc::new(pruned),
237 })))
238 }
239}
240
241fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
246 match plan {
247 LogicalPlan::TableScan(TableScan { table_name, .. }) => {
248 Some((table_name.table().to_string(), None))
249 }
250 LogicalPlan::SubqueryAlias(alias) => {
251 let alias_name = alias.alias.table().to_string();
252 scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
253 }
254 _ => None,
255 }
256}
257
258fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
260 scan_table_name_and_alias(plan).map(|(name, _)| name)
261}
262
263impl fmt::Display for crate::parser::lookup_table::ConnectorType {
265 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 match self {
267 Self::PostgresCdc => write!(f, "postgres-cdc"),
268 Self::MysqlCdc => write!(f, "mysql-cdc"),
269 Self::Redis => write!(f, "redis"),
270 Self::S3Parquet => write!(f, "s3-parquet"),
271 Self::Static => write!(f, "static"),
272 Self::Custom(s) => write!(f, "{s}"),
273 }
274 }
275}
276
277impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 match self {
280 Self::Replicated => write!(f, "replicated"),
281 Self::Partitioned => write!(f, "partitioned"),
282 Self::OnDemand => write!(f, "on-demand"),
283 }
284 }
285}
286
287impl fmt::Display for crate::parser::lookup_table::PushdownMode {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 match self {
290 Self::Auto => write!(f, "auto"),
291 Self::Enabled => write!(f, "enabled"),
292 Self::Disabled => write!(f, "disabled"),
293 }
294 }
295}
296
#[cfg(test)]
mod tests {
    //! Unit tests for the lookup-join optimizer rules and the `Display` impls
    //! of the lookup-table option enums.
    //!
    //! The rewrite tests plan SQL with `into_unoptimized_plan()` and apply
    //! `LookupJoinRewriteRule` alone via `transform_down`, i.e. without the
    //! rest of the optimizer pipeline.

    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Lookup-table registration for `customers(id, name)` with primary key `id`.
    // NOTE(review): `id` is Int32 here but Int64 in `register_test_tables`;
    // harmless for these tests, but worth aligning.
    fn test_lookup_info() -> LookupTableInfo {
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties: LookupTableProperties {
                connector: ConnectorType::PostgresCdc,
                connection: Some("postgresql://localhost/db".to_string()),
                strategy: LookupStrategy::Replicated,
                cache_memory: Some(ByteSize(512 * 1024 * 1024)),
                cache_disk: None,
                cache_ttl: None,
                pushdown_mode: PushdownMode::Auto,
            },
            arrow_schema,
            #[allow(clippy::disallowed_types)]
            raw_options: std::collections::HashMap::new(),
        }
    }

    /// Registers empty `orders` (stream) and `customers` (lookup) tables.
    fn register_test_tables(ctx: &SessionContext) {
        let orders_schema = Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]));
        let customers_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        ctx.register_batch(
            "orders",
            arrow::array::RecordBatch::new_empty(orders_schema),
        )
        .unwrap();
        ctx.register_batch(
            "customers",
            arrow::array::RecordBatch::new_empty(customers_schema),
        )
        .unwrap();
    }

    /// An inner join against a registered lookup table is rewritten.
    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(transformed.transformed);
        // Checked via the Debug rendering of the resulting plan.
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    /// With no registered lookup tables, an ordinary join is left unchanged.
    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        let schema_a = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "a_id",
            DataType::Int64,
            false,
        )]));
        ctx.register_batch("a", arrow::array::RecordBatch::new_empty(schema_a))
            .unwrap();
        ctx.register_batch("b", arrow::array::RecordBatch::new_empty(schema_b))
            .unwrap();

        let plan = ctx
            .sql("SELECT * FROM a JOIN b ON a.id = b.a_id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(!transformed.transformed);
    }

    /// A LEFT JOIN with the lookup table on the right maps to `LeftOuter`.
    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    /// Built-in connectors render as fixed labels; custom ones verbatim.
    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}