1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::{DFSchema, assert_or_internal_err};
23use datafusion_common::{NullEquality, Result};
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// A join key pair: the expression evaluated against the left input's schema
/// and the expression evaluated against the right input's schema.
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that moves key comparisons out of a join's `filter`
/// expression and into the join's `on` list.
///
/// For a join that carries a filter, conjuncts of the form `left = right`
/// whose two sides form a valid, hashable join key pair are appended to `on`;
/// everything else stays in `filter`. When the join has no keys at all and no
/// `=` conjunct qualifies, `IS NOT DISTINCT FROM` conjuncts are tried instead
/// and, if any qualify, the join's null equality is set to
/// `NullEquality::NullEqualsNull`.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates the rule. The rule holds no state, so this is equivalent to
    /// [`ExtractEquijoinPredicate::default`].
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn name(&self) -> &str {
57 "extract_equijoin_predicate"
58 }
59
60 fn apply_order(&self) -> Option<ApplyOrder> {
61 Some(ApplyOrder::BottomUp)
62 }
63
64 fn rewrite(
65 &self,
66 plan: LogicalPlan,
67 _config: &dyn OptimizerConfig,
68 ) -> Result<Transformed<LogicalPlan>> {
69 match plan {
70 LogicalPlan::Join(Join {
71 left,
72 right,
73 mut on,
74 filter: Some(expr),
75 join_type,
76 join_constraint,
77 schema,
78 null_equality,
79 }) => {
80 let left_schema = left.schema();
81 let right_schema = right.schema();
82 let (equijoin_predicates, non_equijoin_expr) =
83 split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85 if on.is_empty()
94 && equijoin_predicates.is_empty()
95 && non_equijoin_expr.is_some()
96 {
97 let expr = non_equijoin_expr.clone().unwrap();
99 let (equijoin_predicates, non_equijoin_expr) =
100 split_is_not_distinct_from_and_other_join_predicate(
101 expr,
102 left_schema,
103 right_schema,
104 )?;
105
106 if !equijoin_predicates.is_empty() {
107 on.extend(equijoin_predicates);
108
109 return Ok(Transformed::yes(LogicalPlan::Join(Join {
110 left,
111 right,
112 on,
113 filter: non_equijoin_expr,
114 join_type,
115 join_constraint,
116 schema,
117 null_equality: NullEquality::NullEqualsNull,
120 })));
121 }
122 }
123
124 if !equijoin_predicates.is_empty() {
125 on.extend(equijoin_predicates);
126 Ok(Transformed::yes(LogicalPlan::Join(Join {
127 left,
128 right,
129 on,
130 filter: non_equijoin_expr,
131 join_type,
132 join_constraint,
133 schema,
134 null_equality,
135 })))
136 } else {
137 Ok(Transformed::no(LogicalPlan::Join(Join {
138 left,
139 right,
140 on,
141 filter: non_equijoin_expr,
142 join_type,
143 join_constraint,
144 schema,
145 null_equality,
146 })))
147 }
148 }
149 _ => Ok(Transformed::no(plan)),
150 }
151 }
152}
153
154fn split_eq_and_noneq_join_predicate(
174 filter: Expr,
175 left_schema: &DFSchema,
176 right_schema: &DFSchema,
177) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
178 split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq)
179}
180
181fn split_is_not_distinct_from_and_other_join_predicate(
207 filter: Expr,
208 left_schema: &DFSchema,
209 right_schema: &DFSchema,
210) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
211 split_op_and_other_join_predicates(
212 filter,
213 left_schema,
214 right_schema,
215 Operator::IsNotDistinctFrom,
216 )
217}
218
219fn split_op_and_other_join_predicates(
221 filter: Expr,
222 left_schema: &DFSchema,
223 right_schema: &DFSchema,
224 operator: Operator,
225) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
226 assert_or_internal_err!(
227 matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom),
228 "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
229 but received: {:?}",
230 operator
231 );
232
233 let exprs = split_conjunction_owned(filter);
234
235 let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
237 let mut accum_filters: Vec<Expr> = vec![];
238 for expr in exprs {
239 match expr {
240 Expr::BinaryExpr(BinaryExpr {
241 ref left,
242 ref op,
243 ref right,
244 }) if *op == operator => {
245 let join_key_pair =
246 find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
247
248 if let Some((left_expr, right_expr)) = join_key_pair {
249 let left_expr_type = left_expr.get_type(left_schema)?;
250 let right_expr_type = right_expr.get_type(right_schema)?;
251
252 if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
253 accum_join_keys.push((left_expr, right_expr));
254 } else {
255 accum_filters.push(expr);
256 }
257 } else {
258 accum_filters.push(expr);
259 }
260 }
261 _ => accum_filters.push(expr),
262 }
263 }
264
265 let result_filter = accum_filters.into_iter().reduce(Expr::and);
266 Ok((accum_join_keys, result_filter))
267}
268
// Snapshot tests for `ExtractEquijoinPredicate`: each test builds a join with
// a filter, runs the rule, and compares the indented plan display against an
// inline snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        JoinType, col, lit, logical_plan::builder::LogicalPlanBuilder,
    };
    use std::sync::Arc;

    // Runs `ExtractEquijoinPredicate` over `$plan` and checks the result
    // against the inline snapshot literal `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    // A filter that is a single column equality becomes the join key and no
    // filter remains.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality between two non-column expressions is also extracted as a
    // join key.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A filter made only of `>=` and `<` comparisons stays entirely in the
    // join filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The join already has an explicit key: the key extracted from the filter
    // is appended after it and the `<` comparison stays in the filter.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The two AND-ed equalities become join keys, while the OR expression
    // (which itself contains an equality) stays in the filter as a whole.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Nested joins: both the inner (t2/t3) and the outer (t1/inner) join get
    // their equality extracted while their other predicate stays a filter.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The outer filter contains `t2.c = t3.c`, whose columns both come from
    // the right-hand input; the expected plan keeps it in the filter rather
    // than turning it into a join key.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The equality is wrapped in an alias; the expected plan shows it as a
    // join key without the alias name.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}