1use std::collections::{HashMap, HashSet};
8use std::fmt;
9use std::sync::Arc;
10
11use datafusion::common::{DFSchema, Result};
12use datafusion::logical_expr::logical_plan::LogicalPlan;
13use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
14use datafusion_common::tree_node::Transformed;
15use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
16
17use crate::datafusion::lookup_join::{
18 JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
19};
20use crate::planner::LookupTableInfo;
21
22#[derive(Debug)]
25pub struct LookupJoinRewriteRule {
26 lookup_tables: HashMap<String, LookupTableInfo>,
28}
29
30impl LookupJoinRewriteRule {
31 #[must_use]
33 pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
34 Self { lookup_tables }
35 }
36
37 fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
40 if let Some(name) = scan_table_name(&join.right) {
42 if self.lookup_tables.contains_key(&name) {
43 return Some((true, name));
44 }
45 }
46 if let Some(name) = scan_table_name(&join.left) {
48 if self.lookup_tables.contains_key(&name) {
49 return Some((false, name));
50 }
51 }
52 None
53 }
54}
55
56impl OptimizerRule for LookupJoinRewriteRule {
57 fn name(&self) -> &'static str {
58 "lookup_join_rewrite"
59 }
60
61 fn apply_order(&self) -> Option<ApplyOrder> {
62 Some(ApplyOrder::BottomUp)
63 }
64
65 fn rewrite(
66 &self,
67 plan: LogicalPlan,
68 _config: &dyn OptimizerConfig,
69 ) -> Result<Transformed<LogicalPlan>> {
70 let LogicalPlan::Join(join) = &plan else {
71 return Ok(Transformed::no(plan));
72 };
73
74 let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
75 return Ok(Transformed::no(plan));
76 };
77
78 let info = &self.lookup_tables[&table_name];
79
80 let (stream_plan, lookup_plan) = if lookup_is_right {
82 (join.left.as_ref(), join.right.as_ref())
83 } else {
84 (join.right.as_ref(), join.left.as_ref())
85 };
86
87 let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
89 let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
90
91 let lookup_schema = lookup_plan.schema().clone();
92
93 let join_keys: Vec<JoinKeyPair> = join
95 .on
96 .iter()
97 .map(|(left_expr, right_expr)| {
98 if lookup_is_right {
99 JoinKeyPair {
100 stream_expr: left_expr.clone(),
101 lookup_column: right_expr.to_string(),
102 }
103 } else {
104 JoinKeyPair {
105 stream_expr: right_expr.clone(),
106 lookup_column: left_expr.to_string(),
107 }
108 }
109 })
110 .collect();
111
112 let join_type = match join.join_type {
114 datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
115 datafusion::logical_expr::JoinType::Left if lookup_is_right => {
116 LookupJoinType::LeftOuter
117 }
118 datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
119 LookupJoinType::LeftOuter
120 }
121 _ => return Ok(Transformed::no(plan)),
122 };
123
124 let required_columns: HashSet<String> = lookup_schema
126 .fields()
127 .iter()
128 .map(|f| f.name().clone())
129 .collect();
130
131 let stream_schema = stream_plan.schema();
133 let merged_fields: Vec<_> = stream_schema
134 .fields()
135 .iter()
136 .chain(lookup_schema.fields().iter())
137 .cloned()
138 .collect();
139 let output_schema = Arc::new(DFSchema::from_unqualified_fields(
140 merged_fields.into(),
141 HashMap::new(),
142 )?);
143
144 let metadata = LookupTableMetadata {
145 connector: info.properties.connector.to_string(),
146 strategy: info.properties.strategy.to_string(),
147 pushdown_mode: info.properties.pushdown_mode.to_string(),
148 primary_key: info.primary_key.clone(),
149 };
150
151 let node = LookupJoinNode::new(
152 stream_plan.clone(),
153 table_name,
154 lookup_schema,
155 join_keys,
156 join_type,
157 vec![], required_columns,
159 output_schema,
160 metadata,
161 )
162 .with_aliases(lookup_alias, stream_alias);
163
164 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
165 node: Arc::new(node),
166 })))
167 }
168}
169
170#[derive(Debug)]
175pub struct LookupColumnPruningRule;
176
177impl OptimizerRule for LookupColumnPruningRule {
178 fn name(&self) -> &'static str {
179 "lookup_column_pruning"
180 }
181
182 fn apply_order(&self) -> Option<ApplyOrder> {
183 Some(ApplyOrder::TopDown)
184 }
185
186 fn rewrite(
187 &self,
188 plan: LogicalPlan,
189 _config: &dyn OptimizerConfig,
190 ) -> Result<Transformed<LogicalPlan>> {
191 let LogicalPlan::Extension(ext) = &plan else {
192 return Ok(Transformed::no(plan));
193 };
194
195 let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
196 return Ok(Transformed::no(plan));
197 };
198
199 let schema = UserDefinedLogicalNodeCore::schema(node);
204 let used: HashSet<String> = schema
205 .fields()
206 .iter()
207 .filter(|f| node.required_lookup_columns().contains(f.name()))
208 .map(|f| f.name().clone())
209 .collect();
210
211 if used == *node.required_lookup_columns() {
212 return Ok(Transformed::no(plan));
213 }
214
215 let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
217 let pruned = LookupJoinNode::new(
218 node_inputs[0].clone(),
219 node.lookup_table_name().to_string(),
220 node.lookup_schema().clone(),
221 node.join_keys().to_vec(),
222 node.join_type(),
223 node.pushdown_predicates().to_vec(),
224 used,
225 schema.clone(),
226 node.metadata().clone(),
227 )
228 .with_local_predicates(node.local_predicates().to_vec())
229 .with_aliases(
230 node.lookup_alias().map(String::from),
231 node.stream_alias().map(String::from),
232 );
233
234 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
235 node: Arc::new(pruned),
236 })))
237 }
238}
239
240fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
245 match plan {
246 LogicalPlan::TableScan(TableScan { table_name, .. }) => {
247 Some((table_name.table().to_string(), None))
248 }
249 LogicalPlan::SubqueryAlias(alias) => {
250 let alias_name = alias.alias.table().to_string();
251 scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
252 }
253 _ => None,
254 }
255}
256
257fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
259 scan_table_name_and_alias(plan).map(|(name, _)| name)
260}
261
262impl fmt::Display for crate::parser::lookup_table::ConnectorType {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 match self {
266 Self::PostgresCdc => write!(f, "postgres-cdc"),
267 Self::MysqlCdc => write!(f, "mysql-cdc"),
268 Self::Redis => write!(f, "redis"),
269 Self::S3Parquet => write!(f, "s3-parquet"),
270 Self::Static => write!(f, "static"),
271 Self::Custom(s) => write!(f, "{s}"),
272 }
273 }
274}
275
276impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 match self {
279 Self::Replicated => write!(f, "replicated"),
280 Self::Partitioned => write!(f, "partitioned"),
281 Self::OnDemand => write!(f, "on-demand"),
282 }
283 }
284}
285
286impl fmt::Display for crate::parser::lookup_table::PushdownMode {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 match self {
289 Self::Auto => write!(f, "auto"),
290 Self::Enabled => write!(f, "enabled"),
291 Self::Disabled => write!(f, "disabled"),
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Lookup-table description for the `customers` table used by the tests.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties,
        }
    }

    /// Registers an empty in-memory table called `name` with the given fields.
    fn register_empty_table(ctx: &SessionContext, name: &str, fields: Vec<Field>) {
        let schema = Arc::new(Schema::new(fields));
        ctx.register_batch(name, arrow::array::RecordBatch::new_empty(schema))
            .unwrap();
    }

    /// Registers the `orders` (stream) and `customers` (lookup) tables.
    fn register_test_tables(ctx: &SessionContext) {
        register_empty_table(
            ctx,
            "orders",
            vec![
                Field::new("order_id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("amount", DataType::Float64, false),
            ],
        );
        register_empty_table(
            ctx,
            "customers",
            vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
            ],
        );
    }

    /// Rule that knows `customers` as its only lookup table.
    fn customers_rule() -> LookupJoinRewriteRule {
        let lookup_tables = HashMap::from([("customers".to_string(), test_lookup_info())]);
        LookupJoinRewriteRule::new(lookup_tables)
    }

    /// Plans `sql` without optimization and runs `rule` over every node.
    async fn apply_rule(
        ctx: &SessionContext,
        rule: &LookupJoinRewriteRule,
        sql: &str,
    ) -> Transformed<LogicalPlan> {
        let plan = ctx.sql(sql).await.unwrap().into_unoptimized_plan();
        plan.transform_down(|node| rule.rewrite(node, &OptimizerContext::new()))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = SessionContext::new();
        register_test_tables(&ctx);
        let rule = customers_rule();

        let transformed = apply_rule(
            &ctx,
            &rule,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
        )
        .await;

        assert!(transformed.transformed);
        let rendered = format!("{:?}", transformed.data);
        assert!(rendered.contains("LookupJoin"), "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = SessionContext::new();
        register_empty_table(&ctx, "a", vec![Field::new("id", DataType::Int64, false)]);
        register_empty_table(&ctx, "b", vec![Field::new("a_id", DataType::Int64, false)]);

        // No lookup tables registered, so nothing may be rewritten.
        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = apply_rule(&ctx, &rule, "SELECT * FROM a JOIN b ON a.id = b.a_id").await;

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = SessionContext::new();
        register_test_tables(&ctx);
        let rule = customers_rule();

        let transformed = apply_rule(
            &ctx,
            &rule,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
        )
        .await;

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        let cases = [
            (ConnectorType::PostgresCdc, "postgres-cdc"),
            (ConnectorType::Redis, "redis"),
            (ConnectorType::Custom("my-conn".into()), "my-conn"),
        ];
        for (connector, expected) in cases {
            assert_eq!(connector.to_string(), expected);
        }
    }

    #[test]
    fn test_fmt_display_strategy() {
        let cases = [
            (LookupStrategy::Replicated, "replicated"),
            (LookupStrategy::OnDemand, "on-demand"),
        ];
        for (strategy, expected) in cases {
            assert_eq!(strategy.to_string(), expected);
        }
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        let cases = [
            (PushdownMode::Auto, "auto"),
            (PushdownMode::Disabled, "disabled"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }
}