1use std::collections::{HashMap, HashSet};
8use std::fmt;
9use std::sync::Arc;
10
11use datafusion::common::{DFSchema, Result};
12use datafusion::logical_expr::logical_plan::LogicalPlan;
13use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
14use datafusion_common::tree_node::Transformed;
15use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
16
17use crate::datafusion::lookup_join::{
18 JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
19};
20use crate::planner::LookupTableInfo;
21
22#[derive(Debug)]
25pub struct LookupJoinRewriteRule {
26 lookup_tables: HashMap<String, LookupTableInfo>,
28}
29
30impl LookupJoinRewriteRule {
31 #[must_use]
33 pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
34 Self { lookup_tables }
35 }
36
37 fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
40 if let Some(name) = scan_table_name(&join.right) {
42 if self.lookup_tables.contains_key(&name) {
43 return Some((true, name));
44 }
45 }
46 if let Some(name) = scan_table_name(&join.left) {
48 if self.lookup_tables.contains_key(&name) {
49 return Some((false, name));
50 }
51 }
52 None
53 }
54}
55
56impl OptimizerRule for LookupJoinRewriteRule {
57 fn name(&self) -> &'static str {
58 "lookup_join_rewrite"
59 }
60
61 fn apply_order(&self) -> Option<ApplyOrder> {
62 Some(ApplyOrder::BottomUp)
63 }
64
65 fn rewrite(
66 &self,
67 plan: LogicalPlan,
68 _config: &dyn OptimizerConfig,
69 ) -> Result<Transformed<LogicalPlan>> {
70 let LogicalPlan::Join(join) = &plan else {
71 return Ok(Transformed::no(plan));
72 };
73
74 let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
75 return Ok(Transformed::no(plan));
76 };
77
78 let info = &self.lookup_tables[&table_name];
79
80 let (stream_plan, lookup_plan) = if lookup_is_right {
82 (join.left.as_ref(), join.right.as_ref())
83 } else {
84 (join.right.as_ref(), join.left.as_ref())
85 };
86
87 let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
89 let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
90
91 let lookup_schema = lookup_plan.schema().clone();
92
93 let join_keys: Vec<JoinKeyPair> = join
95 .on
96 .iter()
97 .map(|(left_expr, right_expr)| {
98 if lookup_is_right {
99 JoinKeyPair {
100 stream_expr: left_expr.clone(),
101 lookup_column: right_expr.to_string(),
102 }
103 } else {
104 JoinKeyPair {
105 stream_expr: right_expr.clone(),
106 lookup_column: left_expr.to_string(),
107 }
108 }
109 })
110 .collect();
111
112 let join_type = match join.join_type {
114 datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
115 datafusion::logical_expr::JoinType::Left if lookup_is_right => {
116 LookupJoinType::LeftOuter
117 }
118 datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
119 LookupJoinType::LeftOuter
120 }
121 _ => return Ok(Transformed::no(plan)),
122 };
123
124 let required_columns: HashSet<String> = lookup_schema
126 .fields()
127 .iter()
128 .map(|f| f.name().clone())
129 .collect();
130
131 let stream_schema = stream_plan.schema();
133 let merged_fields: Vec<_> = stream_schema
134 .fields()
135 .iter()
136 .chain(lookup_schema.fields().iter())
137 .cloned()
138 .collect();
139 let output_schema = Arc::new(DFSchema::from_unqualified_fields(
140 merged_fields.into(),
141 HashMap::new(),
142 )?);
143
144 let metadata = LookupTableMetadata {
145 connector: info.properties.connector.to_string(),
146 strategy: info.properties.strategy.to_string(),
147 pushdown_mode: info.properties.pushdown_mode.to_string(),
148 primary_key: info.primary_key.clone(),
149 };
150
151 let node = LookupJoinNode::new(
152 stream_plan.clone(),
153 table_name,
154 lookup_schema,
155 join_keys,
156 join_type,
157 vec![], required_columns,
159 output_schema,
160 metadata,
161 )
162 .with_aliases(lookup_alias, stream_alias);
163
164 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
165 node: Arc::new(node),
166 })))
167 }
168}
169
/// Optimizer rule that narrows the set of lookup columns a
/// [`LookupJoinNode`] reports as required.
#[derive(Debug)]
pub struct LookupColumnPruningRule;
176
177impl OptimizerRule for LookupColumnPruningRule {
178 fn name(&self) -> &'static str {
179 "lookup_column_pruning"
180 }
181
182 fn apply_order(&self) -> Option<ApplyOrder> {
183 Some(ApplyOrder::TopDown)
184 }
185
186 fn rewrite(
187 &self,
188 plan: LogicalPlan,
189 _config: &dyn OptimizerConfig,
190 ) -> Result<Transformed<LogicalPlan>> {
191 let LogicalPlan::Extension(ext) = &plan else {
192 return Ok(Transformed::no(plan));
193 };
194
195 let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
196 return Ok(Transformed::no(plan));
197 };
198
199 let schema = UserDefinedLogicalNodeCore::schema(node);
204 let used: HashSet<String> = schema
205 .fields()
206 .iter()
207 .filter(|f| node.required_lookup_columns().contains(f.name()))
208 .map(|f| f.name().clone())
209 .collect();
210
211 if used == *node.required_lookup_columns() {
212 return Ok(Transformed::no(plan));
213 }
214
215 let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
217 let pruned = LookupJoinNode::new(
218 node_inputs[0].clone(),
219 node.lookup_table_name().to_string(),
220 node.lookup_schema().clone(),
221 node.join_keys().to_vec(),
222 node.join_type(),
223 node.pushdown_predicates().to_vec(),
224 used,
225 schema.clone(),
226 node.metadata().clone(),
227 )
228 .with_local_predicates(node.local_predicates().to_vec())
229 .with_aliases(
230 node.lookup_alias().map(String::from),
231 node.stream_alias().map(String::from),
232 );
233
234 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
235 node: Arc::new(pruned),
236 })))
237 }
238}
239
240fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
245 match plan {
246 LogicalPlan::TableScan(TableScan { table_name, .. }) => {
247 Some((table_name.table().to_string(), None))
248 }
249 LogicalPlan::SubqueryAlias(alias) => {
250 let alias_name = alias.alias.table().to_string();
251 scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
252 }
253 _ => None,
254 }
255}
256
257fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
259 scan_table_name_and_alias(plan).map(|(name, _)| name)
260}
261
262impl fmt::Display for crate::parser::lookup_table::ConnectorType {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 match self {
266 Self::PostgresCdc => write!(f, "postgres-cdc"),
267 Self::MysqlCdc => write!(f, "mysql-cdc"),
268 Self::Redis => write!(f, "redis"),
269 Self::S3Parquet => write!(f, "s3-parquet"),
270 Self::Static => write!(f, "static"),
271 Self::Custom(s) => write!(f, "{s}"),
272 }
273 }
274}
275
276impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 match self {
279 Self::Replicated => write!(f, "replicated"),
280 Self::Partitioned => write!(f, "partitioned"),
281 Self::OnDemand => write!(f, "on-demand"),
282 }
283 }
284}
285
286impl fmt::Display for crate::parser::lookup_table::PushdownMode {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 match self {
289 Self::Auto => write!(f, "auto"),
290 Self::Enabled => write!(f, "enabled"),
291 Self::Disabled => write!(f, "disabled"),
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::array::RecordBatch;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Metadata for a `customers` lookup table: Postgres CDC connector,
    /// replicated strategy, 512 MiB memory cache.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties,
        }
    }

    /// Registers an empty in-memory table called `name` with `fields`.
    fn register_empty_table(ctx: &SessionContext, name: &str, fields: Vec<Field>) {
        let batch = RecordBatch::new_empty(Arc::new(Schema::new(fields)));
        ctx.register_batch(name, batch).unwrap();
    }

    /// Registers empty `orders` and `customers` tables.
    fn register_test_tables(ctx: &SessionContext) {
        register_empty_table(
            ctx,
            "orders",
            vec![
                Field::new("order_id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("amount", DataType::Float64, false),
            ],
        );
        register_empty_table(
            ctx,
            "customers",
            vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, true),
            ],
        );
    }

    /// Lookup registry that contains only `customers`.
    fn customers_lookup() -> HashMap<String, LookupTableInfo> {
        HashMap::from([("customers".to_string(), test_lookup_info())])
    }

    /// Plans `sql` without optimization and applies [`LookupJoinRewriteRule`]
    /// top-down using `lookup_tables`.
    async fn rewrite_sql(
        ctx: &SessionContext,
        sql: &str,
        lookup_tables: HashMap<String, LookupTableInfo>,
    ) -> Transformed<LogicalPlan> {
        let plan = ctx.sql(sql).await.unwrap().into_unoptimized_plan();
        let rule = LookupJoinRewriteRule::new(lookup_tables);
        let config = OptimizerContext::new();
        plan.transform_down(|node| rule.rewrite(node, &config))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let rewritten = rewrite_sql(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
            customers_lookup(),
        )
        .await;

        assert!(rewritten.transformed);
        let has_lookup = format!("{:?}", rewritten.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        register_empty_table(&ctx, "a", vec![Field::new("id", DataType::Int64, false)]);
        register_empty_table(&ctx, "b", vec![Field::new("a_id", DataType::Int64, false)]);

        let rewritten =
            rewrite_sql(&ctx, "SELECT * FROM a JOIN b ON a.id = b.a_id", HashMap::new()).await;

        assert!(!rewritten.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let rewritten = rewrite_sql(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
            customers_lookup(),
        )
        .await;

        assert!(rewritten.transformed);
        let debug_str = format!("{:?}", rewritten.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}