datafusion_optimizer/
extract_equijoin_predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates in join filters and moves them into the join keys.
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::{DFSchema, assert_or_internal_err};
23use datafusion_common::{NullEquality, Result};
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
/// An equijoin key pair extracted from a join filter: `(left_key, right_key)`,
/// where `left_key` is typed against the left input's schema and `right_key`
/// against the right input's schema.
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that splits conjunctive join predicates into equijoin
/// predicates and (other) filter predicates.
///
/// Join algorithms are often highly optimized for equality predicates such as `x = y`,
/// often called `equijoin` predicates, so it is important to locate such predicates
/// and treat them specially.
///
/// For example, `SELECT ... FROM A JOIN B ON (A.x = B.y AND B.z > 50)`
/// has one equijoin predicate (`A.x = B.y`) and one filter predicate (`B.z > 50`).
/// See [find_valid_equijoin_key_pair] for more information on what predicates
/// are considered equijoins.
///
/// When a join has no equality keys at all, `IS NOT DISTINCT FROM` predicates
/// are extracted as keys instead, and the join is switched to treat
/// `NULL = NULL` as a match so that the semantics are preserved.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates a new [`ExtractEquijoinPredicate`] rule.
    ///
    /// Equivalent to [`ExtractEquijoinPredicate::default`]; the rule is stateless.
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn name(&self) -> &str {
57        "extract_equijoin_predicate"
58    }
59
60    fn apply_order(&self) -> Option<ApplyOrder> {
61        Some(ApplyOrder::BottomUp)
62    }
63
64    fn rewrite(
65        &self,
66        plan: LogicalPlan,
67        _config: &dyn OptimizerConfig,
68    ) -> Result<Transformed<LogicalPlan>> {
69        match plan {
70            LogicalPlan::Join(Join {
71                left,
72                right,
73                mut on,
74                filter: Some(expr),
75                join_type,
76                join_constraint,
77                schema,
78                null_equality,
79            }) => {
80                let left_schema = left.schema();
81                let right_schema = right.schema();
82                let (equijoin_predicates, non_equijoin_expr) =
83                    split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85                // Equi-join operators like HashJoin support a special behavior
86                // that evaluates `NULL = NULL` as true instead of NULL. Therefore,
87                // we transform `t1.c1 IS NOT DISTINCT FROM t2.c1` into an equi-join
88                // and set the `NullEquality` configuration in the join operator.
89                // This allows certain queries to use Hash Join instead of
90                // Nested Loop Join, resulting in better performance.
91                //
92                // Only convert when there are NO equijoin predicates, to be conservative.
93                if on.is_empty()
94                    && equijoin_predicates.is_empty()
95                    && non_equijoin_expr.is_some()
96                {
97                    // SAFETY: checked in the outer `if`
98                    let expr = non_equijoin_expr.clone().unwrap();
99                    let (equijoin_predicates, non_equijoin_expr) =
100                        split_is_not_distinct_from_and_other_join_predicate(
101                            expr,
102                            left_schema,
103                            right_schema,
104                        )?;
105
106                    if !equijoin_predicates.is_empty() {
107                        on.extend(equijoin_predicates);
108
109                        return Ok(Transformed::yes(LogicalPlan::Join(Join {
110                            left,
111                            right,
112                            on,
113                            filter: non_equijoin_expr,
114                            join_type,
115                            join_constraint,
116                            schema,
117                            // According to `is not distinct from`'s semantics, it's
118                            // safe to override it
119                            null_equality: NullEquality::NullEqualsNull,
120                        })));
121                    }
122                }
123
124                if !equijoin_predicates.is_empty() {
125                    on.extend(equijoin_predicates);
126                    Ok(Transformed::yes(LogicalPlan::Join(Join {
127                        left,
128                        right,
129                        on,
130                        filter: non_equijoin_expr,
131                        join_type,
132                        join_constraint,
133                        schema,
134                        null_equality,
135                    })))
136                } else {
137                    Ok(Transformed::no(LogicalPlan::Join(Join {
138                        left,
139                        right,
140                        on,
141                        filter: non_equijoin_expr,
142                        join_type,
143                        join_constraint,
144                        schema,
145                        null_equality,
146                    })))
147                }
148            }
149            _ => Ok(Transformed::no(plan)),
150        }
151    }
152}
153
154/// Splits an ANDed filter expression into equijoin predicates and remaining filters.
155/// Returns all equijoin predicates and the remaining filters combined with AND.
156///
157/// # Example
158///
159/// For the expression `a.id = b.id AND a.x > 10 AND b.x > b.id`, this function will extract `a.id = b.id` as an equijoin predicate.
160///
161/// It first splits the ANDed sub-expressions:
162/// - expr1: a.id = b.id
163/// - expr2: a.x > 10
164/// - expr3: b.x > b.id
165///
166/// Then, it filters out the equijoin predicates and collects the non-equality expressions.
167/// The equijoin condition is:
168/// - It is an equality expression like `lhs == rhs`
169/// - All column references in `lhs` are from the left schema, and all in `rhs` are from the right schema
170///
171/// According to the above rule, `expr1` is the equijoin predicate, while `expr2` and `expr3` are not.
172/// The function returns Ok(\[expr1\], Some(expr2 AND expr3))
173fn split_eq_and_noneq_join_predicate(
174    filter: Expr,
175    left_schema: &DFSchema,
176    right_schema: &DFSchema,
177) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
178    split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq)
179}
180
181/// See `split_eq_and_noneq_join_predicate`'s comment for the idea. This function
182/// is splitting out `is not distinct from` expressions instead of equal exprs.
183/// The `is not distinct from` exprs will be return as `EquijoinPredicate`.
184///
185/// # Example
186/// - Input: `a.id IS NOT DISTINCT FROM b.id AND a.x > 10 AND b.x > b.id`
187/// - Output from this splitter: `Ok([a.id, b.id], Some((a.x > 10) AND (b.x > b.id)))`
188///
189/// # Note
190/// Caller should be cautious -- `is not distinct from` is not equivalent to an
191/// equal expression; the caller is responsible for correctly setting the
192/// `nulls equals nulls` property in the join operator (if it supports it) to
193/// make the transformation valid.
194///
195/// For the above example: in downstream, a valid plan that uses the extracted
196/// equijoin keys should look like:
197///
198/// HashJoin
199/// - on: `a.id = b.id` (equality)
200/// - join_filter: `(a.x > 10) AND (b.x > b.id)`
201/// - nulls_equals_null: `true`
202///
203/// This reflects that `IS NOT DISTINCT FROM` treats `NULL = NULL` as true and
204/// thus requires setting `NullEquality::NullEqualsNull` in the join operator to
205/// preserve semantics while enabling an equi-join implementation (e.g., HashJoin).
206fn split_is_not_distinct_from_and_other_join_predicate(
207    filter: Expr,
208    left_schema: &DFSchema,
209    right_schema: &DFSchema,
210) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
211    split_op_and_other_join_predicates(
212        filter,
213        left_schema,
214        right_schema,
215        Operator::IsNotDistinctFrom,
216    )
217}
218
219/// See comments in `split_eq_and_noneq_join_predicate` for details.
220fn split_op_and_other_join_predicates(
221    filter: Expr,
222    left_schema: &DFSchema,
223    right_schema: &DFSchema,
224    operator: Operator,
225) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
226    assert_or_internal_err!(
227        matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom),
228        "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
229        but received: {:?}",
230        operator
231    );
232
233    let exprs = split_conjunction_owned(filter);
234
235    // Treat 'is not distinct from' comparison as join key in equal joins
236    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
237    let mut accum_filters: Vec<Expr> = vec![];
238    for expr in exprs {
239        match expr {
240            Expr::BinaryExpr(BinaryExpr {
241                ref left,
242                ref op,
243                ref right,
244            }) if *op == operator => {
245                let join_key_pair =
246                    find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
247
248                if let Some((left_expr, right_expr)) = join_key_pair {
249                    let left_expr_type = left_expr.get_type(left_schema)?;
250                    let right_expr_type = right_expr.get_type(right_schema)?;
251
252                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
253                        accum_join_keys.push((left_expr, right_expr));
254                    } else {
255                        accum_filters.push(expr);
256                    }
257                } else {
258                    accum_filters.push(expr);
259                }
260            }
261            _ => accum_filters.push(expr),
262        }
263    }
264
265    let result_filter = accum_filters.into_iter().reduce(Expr::and);
266    Ok((accum_join_keys, result_filter))
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::optimizer::OptimizerContext;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        JoinType, col, lit, logical_plan::builder::LogicalPlanBuilder,
    };
    use std::sync::Arc;

    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join:  Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2
        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// With no `=` keys available, an `IS NOT DISTINCT FROM` conjunct must be
    /// promoted to a join key and the join must switch to `NullEqualsNull`.
    #[test]
    fn join_with_only_is_not_distinct_from_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .is_not_distinct_from(col("t2.a"))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        let rewritten =
            ExtractEquijoinPredicate::new().rewrite(plan, &OptimizerContext::new())?;
        assert!(rewritten.transformed);
        let LogicalPlan::Join(join) = rewritten.data else {
            panic!("expected the rewritten plan to still be a Join");
        };
        assert_eq!(join.on.len(), 1);
        assert!(matches!(join.null_equality, NullEquality::NullEqualsNull));
        assert_eq!(
            join.filter.map(|f| f.to_string()),
            Some("t1.b < Int32(100)".to_string())
        );
        Ok(())
    }

    /// When an `=` key is extracted, `IS NOT DISTINCT FROM` must stay in the
    /// filter and `null_equality` must not be overridden.
    #[test]
    fn is_not_distinct_from_kept_in_filter_when_equijoin_exists() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and(col("t1.b").is_not_distinct_from(col("t2.b"))),
                ),
            )?
            .build()?;

        let rewritten =
            ExtractEquijoinPredicate::new().rewrite(plan, &OptimizerContext::new())?;
        assert!(rewritten.transformed);
        let LogicalPlan::Join(join) = rewritten.data else {
            panic!("expected the rewritten plan to still be a Join");
        };
        assert_eq!(join.on.len(), 1);
        assert!(matches!(join.null_equality, NullEquality::NullEqualsNothing));
        assert!(join.filter.is_some());
        Ok(())
    }
}