Skip to main content

datafusion_optimizer/
extract_equijoin_predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates
//! in join filters and moves them into the join's `on` keys.
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::{DFSchema, assert_or_internal_err};
23use datafusion_common::{NullEquality, Result};
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
27// equijoin predicate
28type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that splits conjunctive join predicates into equijoin
/// predicates and (other) filter predicates.
///
/// Join algorithms are often highly optimized for equality predicates such as
/// `x = y`, commonly called *equijoin* predicates, so it is important to
/// locate such predicates and treat them specially.
///
/// For example, `SELECT ... FROM A JOIN B ON (A.x = B.y AND B.z > 50)` has one
/// equijoin predicate (`A.x = B.y`) and one filter predicate (`B.z > 50`).
/// See [find_valid_equijoin_key_pair] for more information on what predicates
/// are considered equijoins.
///
/// When a join ends up with no equijoin keys at all, `IS NOT DISTINCT FROM`
/// predicates are extracted as keys instead, and the join's null equality is
/// switched to `NullEquality::NullEqualsNull` to keep their semantics.
#[derive(Debug, Default)]
pub struct ExtractEquijoinPredicate;
43
44impl ExtractEquijoinPredicate {
45    #[expect(missing_docs)]
46    pub fn new() -> Self {
47        Self {}
48    }
49}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn name(&self) -> &str {
57        "extract_equijoin_predicate"
58    }
59
60    fn apply_order(&self) -> Option<ApplyOrder> {
61        Some(ApplyOrder::BottomUp)
62    }
63
64    fn rewrite(
65        &self,
66        plan: LogicalPlan,
67        _config: &dyn OptimizerConfig,
68    ) -> Result<Transformed<LogicalPlan>> {
69        match plan {
70            LogicalPlan::Join(Join {
71                left,
72                right,
73                mut on,
74                filter: Some(expr),
75                join_type,
76                join_constraint,
77                schema,
78                null_equality,
79                null_aware,
80            }) => {
81                let left_schema = left.schema();
82                let right_schema = right.schema();
83                let (equijoin_predicates, non_equijoin_expr) =
84                    split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
85
86                // Equi-join operators like HashJoin support a special behavior
87                // that evaluates `NULL = NULL` as true instead of NULL. Therefore,
88                // we transform `t1.c1 IS NOT DISTINCT FROM t2.c1` into an equi-join
89                // and set the `NullEquality` configuration in the join operator.
90                // This allows certain queries to use Hash Join instead of
91                // Nested Loop Join, resulting in better performance.
92                //
93                // Only convert when there are NO equijoin predicates, to be conservative.
94                if on.is_empty()
95                    && equijoin_predicates.is_empty()
96                    && non_equijoin_expr.is_some()
97                {
98                    // SAFETY: checked in the outer `if`
99                    let expr = non_equijoin_expr.clone().unwrap();
100                    let (equijoin_predicates, non_equijoin_expr) =
101                        split_is_not_distinct_from_and_other_join_predicate(
102                            expr,
103                            left_schema,
104                            right_schema,
105                        )?;
106
107                    if !equijoin_predicates.is_empty() {
108                        on.extend(equijoin_predicates);
109
110                        return Ok(Transformed::yes(LogicalPlan::Join(Join {
111                            left,
112                            right,
113                            on,
114                            filter: non_equijoin_expr,
115                            join_type,
116                            join_constraint,
117                            schema,
118                            // According to `is not distinct from`'s semantics, it's
119                            // safe to override it
120                            null_equality: NullEquality::NullEqualsNull,
121                            null_aware,
122                        })));
123                    }
124                }
125
126                if !equijoin_predicates.is_empty() {
127                    on.extend(equijoin_predicates);
128                    Ok(Transformed::yes(LogicalPlan::Join(Join {
129                        left,
130                        right,
131                        on,
132                        filter: non_equijoin_expr,
133                        join_type,
134                        join_constraint,
135                        schema,
136                        null_equality,
137                        null_aware,
138                    })))
139                } else {
140                    Ok(Transformed::no(LogicalPlan::Join(Join {
141                        left,
142                        right,
143                        on,
144                        filter: non_equijoin_expr,
145                        join_type,
146                        join_constraint,
147                        schema,
148                        null_equality,
149                        null_aware,
150                    })))
151                }
152            }
153            _ => Ok(Transformed::no(plan)),
154        }
155    }
156}
157
158/// Splits an ANDed filter expression into equijoin predicates and remaining filters.
159/// Returns all equijoin predicates and the remaining filters combined with AND.
160///
161/// # Example
162///
163/// For the expression `a.id = b.id AND a.x > 10 AND b.x > b.id`, this function will extract `a.id = b.id` as an equijoin predicate.
164///
165/// It first splits the ANDed sub-expressions:
166/// - expr1: a.id = b.id
167/// - expr2: a.x > 10
168/// - expr3: b.x > b.id
169///
170/// Then, it filters out the equijoin predicates and collects the non-equality expressions.
171/// The equijoin condition is:
172/// - It is an equality expression like `lhs == rhs`
173/// - All column references in `lhs` are from the left schema, and all in `rhs` are from the right schema
174///
175/// According to the above rule, `expr1` is the equijoin predicate, while `expr2` and `expr3` are not.
176/// The function returns Ok(\[expr1\], Some(expr2 AND expr3))
177fn split_eq_and_noneq_join_predicate(
178    filter: Expr,
179    left_schema: &DFSchema,
180    right_schema: &DFSchema,
181) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
182    split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq)
183}
184
185/// See `split_eq_and_noneq_join_predicate`'s comment for the idea. This function
186/// is splitting out `is not distinct from` expressions instead of equal exprs.
187/// The `is not distinct from` exprs will be return as `EquijoinPredicate`.
188///
189/// # Example
190/// - Input: `a.id IS NOT DISTINCT FROM b.id AND a.x > 10 AND b.x > b.id`
191/// - Output from this splitter: `Ok([a.id, b.id], Some((a.x > 10) AND (b.x > b.id)))`
192///
193/// # Note
194/// Caller should be cautious -- `is not distinct from` is not equivalent to an
195/// equal expression; the caller is responsible for correctly setting the
196/// `nulls equals nulls` property in the join operator (if it supports it) to
197/// make the transformation valid.
198///
199/// For the above example: in downstream, a valid plan that uses the extracted
200/// equijoin keys should look like:
201///
202/// HashJoin
203/// - on: `a.id = b.id` (equality)
204/// - join_filter: `(a.x > 10) AND (b.x > b.id)`
205/// - nulls_equals_null: `true`
206///
207/// This reflects that `IS NOT DISTINCT FROM` treats `NULL = NULL` as true and
208/// thus requires setting `NullEquality::NullEqualsNull` in the join operator to
209/// preserve semantics while enabling an equi-join implementation (e.g., HashJoin).
210fn split_is_not_distinct_from_and_other_join_predicate(
211    filter: Expr,
212    left_schema: &DFSchema,
213    right_schema: &DFSchema,
214) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
215    split_op_and_other_join_predicates(
216        filter,
217        left_schema,
218        right_schema,
219        Operator::IsNotDistinctFrom,
220    )
221}
222
223/// See comments in `split_eq_and_noneq_join_predicate` for details.
224fn split_op_and_other_join_predicates(
225    filter: Expr,
226    left_schema: &DFSchema,
227    right_schema: &DFSchema,
228    operator: Operator,
229) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
230    assert_or_internal_err!(
231        matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom),
232        "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
233        but received: {:?}",
234        operator
235    );
236
237    let exprs = split_conjunction_owned(filter);
238
239    // Treat 'is not distinct from' comparison as join key in equal joins
240    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
241    let mut accum_filters: Vec<Expr> = vec![];
242    for expr in exprs {
243        match expr {
244            Expr::BinaryExpr(BinaryExpr {
245                ref left,
246                ref op,
247                ref right,
248            }) if *op == operator => {
249                let join_key_pair =
250                    find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
251
252                if let Some((left_expr, right_expr)) = join_key_pair {
253                    let left_expr_type = left_expr.get_type(left_schema)?;
254                    let right_expr_type = right_expr.get_type(right_schema)?;
255
256                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
257                        accum_join_keys.push((left_expr, right_expr));
258                    } else {
259                        accum_filters.push(expr);
260                    }
261                } else {
262                    accum_filters.push(expr);
263                }
264            }
265            _ => accum_filters.push(expr),
266        }
267    }
268
269    let result_filter = accum_filters.into_iter().reduce(Expr::and);
270    Ok((accum_join_keys, result_filter))
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        JoinType, col, lit, logical_plan::builder::LogicalPlanBuilder,
    };
    use std::sync::Arc;

    // Runs `ExtractEquijoinPredicate` over `$plan` and compares the indented
    // display of the optimized plan with the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
                Arc::new(ExtractEquijoinPredicate::new());
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    // Builds `t2 LEFT JOIN t3 ON t2.a = t3.a AND t2.a + t3.b > 100`, the inner
    // join shared by the multi-table tests.
    fn t2_left_join_t3() -> Result<LogicalPlan> {
        let t2_scan = test_table_scan_with_name("t2")?;
        let t3_scan = test_table_scan_with_name("t3")?;
        let inner_on = col("t2.a")
            .eq(col("t3.a"))
            .and((col("t2.a") + col("t3.b")).gt(lit(100u32)));
        LogicalPlanBuilder::from(t2_scan)
            .join_on(t3_scan, JoinType::Left, Some(inner_on))?
            .build()
    }

    // A lone column-to-column equality becomes the only join key.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let right_scan = test_table_scan_with_name("t2")?;
        let on_clause = col("t1.a").eq(col("t2.a"));

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_on(right_scan, JoinType::Left, Some(on_clause))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality between expressions (not just columns) is also extracted.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let right_scan = test_table_scan_with_name("t2")?;
        let on_clause = (col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32));

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_on(right_scan, JoinType::Left, Some(on_clause))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Without any equality, the whole condition stays in the join filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let right_scan = test_table_scan_with_name("t2")?;
        let on_clause = (col("t1.a") + lit(10i64))
            .gt_eq(col("t2.a") * lit(2u32))
            .and(col("t1.b").lt(lit(100i32)));

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_on(right_scan, JoinType::Left, Some(on_clause))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join:  Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Extracted keys are appended after keys that were already present.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let right_scan = test_table_scan_with_name("t2")?;
        let left_keys = vec![col("t1.a") + lit(11u32)];
        let right_keys = vec![col("t2.a") * lit(2u32)];
        let extra_filter = (col("t1.a") + lit(10i64))
            .eq(col("t2.a") * lit(2u32))
            .and(col("t1.b").lt(lit(100i32)));

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_with_expr_keys(
                right_scan,
                JoinType::Left,
                (left_keys, right_keys),
                Some(extra_filter),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Equalities nested under OR are not top-level conjuncts and stay in the filter.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let right_scan = test_table_scan_with_name("t2")?;
        let either_branch = col("t1.c")
            .eq(col("t2.c"))
            .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")));
        let key_equalities = col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b")));
        let on_clause = either_branch.and(key_equalities);

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_on(right_scan, JoinType::Left, Some(on_clause))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Both the outer and the nested join get their equalities extracted.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let nested_join = t2_left_join_t3()?;
        let sum_of_c = col("t1.c") + col("t2.c") + col("t3.c");
        let on_clause = col("t1.a").eq(col("t2.a")).and(sum_of_c.lt(lit(100u32)));

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_on(nested_join, JoinType::Left, Some(on_clause))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An equality whose two sides both come from the right input stays a filter.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let left_scan = test_table_scan_with_name("t1")?;
        let nested_join = t2_left_join_t3()?;
        let on_clause = col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")));

        let plan = LogicalPlanBuilder::from(left_scan)
            .join_on(nested_join, JoinType::Left, Some(on_clause))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // An aliased equality is unwrapped and extracted as a key.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2
        let lhs = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?;
        let rhs = col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?;
        let aliased_filter = lhs.eq(rhs).alias("t1.a + 1 = t2.a + 2");

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(aliased_filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}