1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::{internal_err, DFSchema};
23use datafusion_common::{NullEquality, Result};
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// A pair of join key expressions taken from one equality-style predicate:
/// `(left_key, right_key)`, where the first element is typed against the left
/// input's schema and the second against the right input's schema.
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that splits a join's filter into join keys and a residual
/// filter.
///
/// Conjuncts of the form `left_expr = right_expr` that can serve as join keys
/// are moved from the join's `filter` into its `on` list; everything else
/// stays in the filter. When the join has no keys at all and the filter
/// yields no `=` keys, `IS NOT DISTINCT FROM` conjuncts are extracted instead
/// and the join is marked as treating NULL as equal to NULL.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates the rule.
    ///
    /// The rule carries no state, so this is equivalent to
    /// `ExtractEquijoinPredicate::default()`.
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn name(&self) -> &str {
57 "extract_equijoin_predicate"
58 }
59
60 fn apply_order(&self) -> Option<ApplyOrder> {
61 Some(ApplyOrder::BottomUp)
62 }
63
64 fn rewrite(
65 &self,
66 plan: LogicalPlan,
67 _config: &dyn OptimizerConfig,
68 ) -> Result<Transformed<LogicalPlan>> {
69 match plan {
70 LogicalPlan::Join(Join {
71 left,
72 right,
73 mut on,
74 filter: Some(expr),
75 join_type,
76 join_constraint,
77 schema,
78 null_equality,
79 }) => {
80 let left_schema = left.schema();
81 let right_schema = right.schema();
82 let (equijoin_predicates, non_equijoin_expr) =
83 split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85 if on.is_empty()
94 && equijoin_predicates.is_empty()
95 && non_equijoin_expr.is_some()
96 {
97 let expr = non_equijoin_expr.clone().unwrap();
99 let (equijoin_predicates, non_equijoin_expr) =
100 split_is_not_distinct_from_and_other_join_predicate(
101 expr,
102 left_schema,
103 right_schema,
104 )?;
105
106 if !equijoin_predicates.is_empty() {
107 on.extend(equijoin_predicates);
108
109 return Ok(Transformed::yes(LogicalPlan::Join(Join {
110 left,
111 right,
112 on,
113 filter: non_equijoin_expr,
114 join_type,
115 join_constraint,
116 schema,
117 null_equality: NullEquality::NullEqualsNull,
120 })));
121 }
122 }
123
124 if !equijoin_predicates.is_empty() {
125 on.extend(equijoin_predicates);
126 Ok(Transformed::yes(LogicalPlan::Join(Join {
127 left,
128 right,
129 on,
130 filter: non_equijoin_expr,
131 join_type,
132 join_constraint,
133 schema,
134 null_equality,
135 })))
136 } else {
137 Ok(Transformed::no(LogicalPlan::Join(Join {
138 left,
139 right,
140 on,
141 filter: non_equijoin_expr,
142 join_type,
143 join_constraint,
144 schema,
145 null_equality,
146 })))
147 }
148 }
149 _ => Ok(Transformed::no(plan)),
150 }
151 }
152}
153
154fn split_eq_and_noneq_join_predicate(
174 filter: Expr,
175 left_schema: &DFSchema,
176 right_schema: &DFSchema,
177) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
178 split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq)
179}
180
181fn split_is_not_distinct_from_and_other_join_predicate(
207 filter: Expr,
208 left_schema: &DFSchema,
209 right_schema: &DFSchema,
210) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
211 split_op_and_other_join_predicates(
212 filter,
213 left_schema,
214 right_schema,
215 Operator::IsNotDistinctFrom,
216 )
217}
218
219fn split_op_and_other_join_predicates(
221 filter: Expr,
222 left_schema: &DFSchema,
223 right_schema: &DFSchema,
224 operator: Operator,
225) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
226 if !matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom) {
227 return internal_err!(
228 "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
229 but received: {:?}",
230 operator
231 );
232 }
233
234 let exprs = split_conjunction_owned(filter);
235
236 let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
238 let mut accum_filters: Vec<Expr> = vec![];
239 for expr in exprs {
240 match expr {
241 Expr::BinaryExpr(BinaryExpr {
242 ref left,
243 ref op,
244 ref right,
245 }) if *op == operator => {
246 let join_key_pair =
247 find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
248
249 if let Some((left_expr, right_expr)) = join_key_pair {
250 let left_expr_type = left_expr.get_type(left_schema)?;
251 let right_expr_type = right_expr.get_type(right_schema)?;
252
253 if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
254 accum_join_keys.push((left_expr, right_expr));
255 } else {
256 accum_filters.push(expr);
257 }
258 } else {
259 accum_filters.push(expr);
260 }
261 }
262 _ => accum_filters.push(expr),
263 }
264 }
265
266 let result_filter = accum_filters.into_iter().reduce(Expr::and);
267 Ok((accum_join_keys, result_filter))
268}
269
// Snapshot tests for `ExtractEquijoinPredicate`. Each test builds a logical
// plan whose join condition is supplied as a filter, runs the rule, and
// compares the indented plan display against an inline snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
    };
    use std::sync::Arc;

    // Runs `ExtractEquijoinPredicate` on `$plan` and checks the result against
    // the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    // A single `column = column` filter becomes the join key and no filter
    // remains.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Equality between two expressions (not bare columns) is extracted as a
    // join key as well.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // A filter made only of `>=` and `<` comparisons yields no join key; the
    // whole condition stays in the join filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The join already has an expression key; the equality found in the
    // filter is appended after it and only the `<` comparison remains as the
    // filter.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Only top-level AND conjuncts are candidates: the two equalities joined
    // by AND become keys, while the equality nested inside the OR stays in
    // the filter together with its sibling.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Nested joins: the expected plan shows keys extracted for both the inner
    // (t2/t3) join and the outer (t1/inner) join, each keeping its own
    // non-equality filter.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // In the outer join, `t2.c = t3.c` is expected to remain a filter.
    // NOTE(review): presumably because both columns belong to the right-hand
    // input (the t2/t3 join), so the equality does not relate the two sides
    // of the outer join - confirm against `find_valid_equijoin_key_pair`.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // The equality is wrapped in an alias; the expected plan shows it
    // extracted as a join key with the alias gone.
    // NOTE(review): the alias is presumably removed by
    // `split_conjunction_owned` - confirm.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        // Both literals are cast to UInt32, the type of the `a` columns.
        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}