1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::{DFSchema, assert_or_internal_err};
23use datafusion_common::{NullEquality, Result};
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// A join key pair: the first expression is typed against the join's left
/// input schema and the second against its right input schema.
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that inspects the filter of a [`LogicalPlan::Join`] and
/// moves the conjuncts that can serve as hashable join keys into the join's
/// `on` list, leaving the remaining conjuncts as the join filter.
///
/// `=` conjuncts are extracted first. Only when that yields no keys and the
/// join had no keys to begin with are `IS NOT DISTINCT FROM` conjuncts
/// extracted instead; in that case the join's null equality is set to
/// [`NullEquality::NullEqualsNull`].
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;
43
44impl ExtractEquijoinPredicate {
45 #[expect(missing_docs)]
46 pub fn new() -> Self {
47 Self {}
48 }
49}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn name(&self) -> &str {
57 "extract_equijoin_predicate"
58 }
59
60 fn apply_order(&self) -> Option<ApplyOrder> {
61 Some(ApplyOrder::BottomUp)
62 }
63
64 fn rewrite(
65 &self,
66 plan: LogicalPlan,
67 _config: &dyn OptimizerConfig,
68 ) -> Result<Transformed<LogicalPlan>> {
69 match plan {
70 LogicalPlan::Join(Join {
71 left,
72 right,
73 mut on,
74 filter: Some(expr),
75 join_type,
76 join_constraint,
77 schema,
78 null_equality,
79 null_aware,
80 }) => {
81 let left_schema = left.schema();
82 let right_schema = right.schema();
83 let (equijoin_predicates, non_equijoin_expr) =
84 split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
85
86 if on.is_empty()
95 && equijoin_predicates.is_empty()
96 && non_equijoin_expr.is_some()
97 {
98 let expr = non_equijoin_expr.clone().unwrap();
100 let (equijoin_predicates, non_equijoin_expr) =
101 split_is_not_distinct_from_and_other_join_predicate(
102 expr,
103 left_schema,
104 right_schema,
105 )?;
106
107 if !equijoin_predicates.is_empty() {
108 on.extend(equijoin_predicates);
109
110 return Ok(Transformed::yes(LogicalPlan::Join(Join {
111 left,
112 right,
113 on,
114 filter: non_equijoin_expr,
115 join_type,
116 join_constraint,
117 schema,
118 null_equality: NullEquality::NullEqualsNull,
121 null_aware,
122 })));
123 }
124 }
125
126 if !equijoin_predicates.is_empty() {
127 on.extend(equijoin_predicates);
128 Ok(Transformed::yes(LogicalPlan::Join(Join {
129 left,
130 right,
131 on,
132 filter: non_equijoin_expr,
133 join_type,
134 join_constraint,
135 schema,
136 null_equality,
137 null_aware,
138 })))
139 } else {
140 Ok(Transformed::no(LogicalPlan::Join(Join {
141 left,
142 right,
143 on,
144 filter: non_equijoin_expr,
145 join_type,
146 join_constraint,
147 schema,
148 null_equality,
149 null_aware,
150 })))
151 }
152 }
153 _ => Ok(Transformed::no(plan)),
154 }
155 }
156}
157
/// Splits `filter` into join key pairs taken from its `=` conjuncts and the
/// filter expression that remains.
///
/// Thin wrapper that calls `split_op_and_other_join_predicates` with
/// [`Operator::Eq`]; the returned tuple and any error come straight from
/// that call.
fn split_eq_and_noneq_join_predicate(
    filter: Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
    split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq)
}
184
/// Splits `filter` into join key pairs taken from its `IS NOT DISTINCT FROM`
/// conjuncts and the filter expression that remains.
///
/// Thin wrapper that calls `split_op_and_other_join_predicates` with
/// [`Operator::IsNotDistinctFrom`]; the returned tuple and any error come
/// straight from that call.
fn split_is_not_distinct_from_and_other_join_predicate(
    filter: Expr,
    left_schema: &DFSchema,
    right_schema: &DFSchema,
) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
    split_op_and_other_join_predicates(
        filter,
        left_schema,
        right_schema,
        Operator::IsNotDistinctFrom,
    )
}
222
223fn split_op_and_other_join_predicates(
225 filter: Expr,
226 left_schema: &DFSchema,
227 right_schema: &DFSchema,
228 operator: Operator,
229) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
230 assert_or_internal_err!(
231 matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom),
232 "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
233 but received: {:?}",
234 operator
235 );
236
237 let exprs = split_conjunction_owned(filter);
238
239 let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
241 let mut accum_filters: Vec<Expr> = vec![];
242 for expr in exprs {
243 match expr {
244 Expr::BinaryExpr(BinaryExpr {
245 ref left,
246 ref op,
247 ref right,
248 }) if *op == operator => {
249 let join_key_pair =
250 find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
251
252 if let Some((left_expr, right_expr)) = join_key_pair {
253 let left_expr_type = left_expr.get_type(left_schema)?;
254 let right_expr_type = right_expr.get_type(right_schema)?;
255
256 if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
257 accum_join_keys.push((left_expr, right_expr));
258 } else {
259 accum_filters.push(expr);
260 }
261 } else {
262 accum_filters.push(expr);
263 }
264 }
265 _ => accum_filters.push(expr),
266 }
267 }
268
269 let result_filter = accum_filters.into_iter().reduce(Expr::and);
270 Ok((accum_join_keys, result_filter))
271}
272
// Snapshot tests for `ExtractEquijoinPredicate`: each test builds a join whose
// condition is supplied as a filter and checks the optimized plan's indented
// display (including schemas) against an inline snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        JoinType, col, lit, logical_plan::builder::LogicalPlanBuilder,
    };
    use std::sync::Arc;

    // Runs `ExtractEquijoinPredicate` as the rule under test and forwards the
    // plan and the inline snapshot to the crate-level snapshot assertion macro.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    // A single column-to-column equality in the filter becomes the join key
    // and no filter remains.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality between two non-column expressions, one per side, is also
    // extracted as a join key.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // With no equality conjunct at all, nothing is extracted and the whole
    // condition stays in the join filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A join that already has explicit keys gets the equality from its filter
    // appended after the existing key; the non-equality conjunct stays in the
    // filter.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Only top-level AND conjuncts are candidates: the equality nested under
    // OR stays in the filter while the two AND-ed equalities become keys.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Nested joins: both the outer and the inner join have their equality
    // extracted, each keeping its own non-equality conjunct as filter.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality whose two operands both come from the right input
    // (`t2.c = t3.c`) is not a join key for the outer join and stays in its
    // filter.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality wrapped in an alias is still extracted as a join key; the
    // alias name does not appear in the optimized plan.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}