datafusion_optimizer/
extract_equijoin_predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::{internal_err, DFSchema};
23use datafusion_common::{NullEquality, Result};
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
// A single equijoin key: a pair of expressions `(left, right)` extracted from
// one join predicate. The splitter in this file types the first expression
// against the left join input's schema and the second against the right
// input's schema.
type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that splits conjunctive join predicates into equijoin
/// predicates and (other) filter predicates.
///
/// Join algorithms are often highly optimized for equality predicates such as
/// `x = y`, often called `equijoin` predicates, so it is important to locate
/// such predicates and treat them specially.
///
/// For example, `SELECT ... FROM A JOIN B ON (A.x = B.y AND B.z > 50)`
/// has one equijoin predicate (`A.x = B.y`) and one filter predicate
/// (`B.z > 50`).
///
/// See [find_valid_equijoin_key_pair] for more information on what predicates
/// are considered equijoins.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates a new instance of the rule.
    ///
    /// The rule is a unit struct and carries no state, so this is equivalent
    /// to [`ExtractEquijoinPredicate::default`].
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn name(&self) -> &str {
57        "extract_equijoin_predicate"
58    }
59
60    fn apply_order(&self) -> Option<ApplyOrder> {
61        Some(ApplyOrder::BottomUp)
62    }
63
64    fn rewrite(
65        &self,
66        plan: LogicalPlan,
67        _config: &dyn OptimizerConfig,
68    ) -> Result<Transformed<LogicalPlan>> {
69        match plan {
70            LogicalPlan::Join(Join {
71                left,
72                right,
73                mut on,
74                filter: Some(expr),
75                join_type,
76                join_constraint,
77                schema,
78                null_equality,
79            }) => {
80                let left_schema = left.schema();
81                let right_schema = right.schema();
82                let (equijoin_predicates, non_equijoin_expr) =
83                    split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85                // Equi-join operators like HashJoin support a special behavior
86                // that evaluates `NULL = NULL` as true instead of NULL. Therefore,
87                // we transform `t1.c1 IS NOT DISTINCT FROM t2.c1` into an equi-join
88                // and set the `NullEquality` configuration in the join operator.
89                // This allows certain queries to use Hash Join instead of
90                // Nested Loop Join, resulting in better performance.
91                //
92                // Only convert when there are NO equijoin predicates, to be conservative.
93                if on.is_empty()
94                    && equijoin_predicates.is_empty()
95                    && non_equijoin_expr.is_some()
96                {
97                    // SAFETY: checked in the outer `if`
98                    let expr = non_equijoin_expr.clone().unwrap();
99                    let (equijoin_predicates, non_equijoin_expr) =
100                        split_is_not_distinct_from_and_other_join_predicate(
101                            expr,
102                            left_schema,
103                            right_schema,
104                        )?;
105
106                    if !equijoin_predicates.is_empty() {
107                        on.extend(equijoin_predicates);
108
109                        return Ok(Transformed::yes(LogicalPlan::Join(Join {
110                            left,
111                            right,
112                            on,
113                            filter: non_equijoin_expr,
114                            join_type,
115                            join_constraint,
116                            schema,
117                            // According to `is not distinct from`'s semantics, it's
118                            // safe to override it
119                            null_equality: NullEquality::NullEqualsNull,
120                        })));
121                    }
122                }
123
124                if !equijoin_predicates.is_empty() {
125                    on.extend(equijoin_predicates);
126                    Ok(Transformed::yes(LogicalPlan::Join(Join {
127                        left,
128                        right,
129                        on,
130                        filter: non_equijoin_expr,
131                        join_type,
132                        join_constraint,
133                        schema,
134                        null_equality,
135                    })))
136                } else {
137                    Ok(Transformed::no(LogicalPlan::Join(Join {
138                        left,
139                        right,
140                        on,
141                        filter: non_equijoin_expr,
142                        join_type,
143                        join_constraint,
144                        schema,
145                        null_equality,
146                    })))
147                }
148            }
149            _ => Ok(Transformed::no(plan)),
150        }
151    }
152}
153
154/// Splits an ANDed filter expression into equijoin predicates and remaining filters.
155/// Returns all equijoin predicates and the remaining filters combined with AND.
156///
157/// # Example
158///
159/// For the expression `a.id = b.id AND a.x > 10 AND b.x > b.id`, this function will extract `a.id = b.id` as an equijoin predicate.
160///
161/// It first splits the ANDed sub-expressions:
162/// - expr1: a.id = b.id
163/// - expr2: a.x > 10
164/// - expr3: b.x > b.id
165///
166/// Then, it filters out the equijoin predicates and collects the non-equality expressions.
167/// The equijoin condition is:
168/// - It is an equality expression like `lhs == rhs`
169/// - All column references in `lhs` are from the left schema, and all in `rhs` are from the right schema
170///
171/// According to the above rule, `expr1` is the equijoin predicate, while `expr2` and `expr3` are not.
172/// The function returns Ok(\[expr1\], Some(expr2 AND expr3))
173fn split_eq_and_noneq_join_predicate(
174    filter: Expr,
175    left_schema: &DFSchema,
176    right_schema: &DFSchema,
177) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
178    split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq)
179}
180
181/// See `split_eq_and_noneq_join_predicate`'s comment for the idea. This function
182/// is splitting out `is not distinct from` expressions instead of equal exprs.
183/// The `is not distinct from` exprs will be return as `EquijoinPredicate`.
184///
185/// # Example
186/// - Input: `a.id IS NOT DISTINCT FROM b.id AND a.x > 10 AND b.x > b.id`
187/// - Output from this splitter: `Ok([a.id, b.id], Some((a.x > 10) AND (b.x > b.id)))`
188///
189/// # Note
190/// Caller should be cautious -- `is not distinct from` is not equivalent to an
191/// equal expression; the caller is responsible for correctly setting the
192/// `nulls equals nulls` property in the join operator (if it supports it) to
193/// make the transformation valid.
194///
195/// For the above example: in downstream, a valid plan that uses the extracted
196/// equijoin keys should look like:
197///
198/// HashJoin
199/// - on: `a.id = b.id` (equality)
200/// - join_filter: `(a.x > 10) AND (b.x > b.id)`
201/// - nulls_equals_null: `true`
202///
203/// This reflects that `IS NOT DISTINCT FROM` treats `NULL = NULL` as true and
204/// thus requires setting `NullEquality::NullEqualsNull` in the join operator to
205/// preserve semantics while enabling an equi-join implementation (e.g., HashJoin).
206fn split_is_not_distinct_from_and_other_join_predicate(
207    filter: Expr,
208    left_schema: &DFSchema,
209    right_schema: &DFSchema,
210) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
211    split_op_and_other_join_predicates(
212        filter,
213        left_schema,
214        right_schema,
215        Operator::IsNotDistinctFrom,
216    )
217}
218
219/// See comments in `split_eq_and_noneq_join_predicate` for details.
220fn split_op_and_other_join_predicates(
221    filter: Expr,
222    left_schema: &DFSchema,
223    right_schema: &DFSchema,
224    operator: Operator,
225) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
226    if !matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom) {
227        return internal_err!(
228            "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \
229            but received: {:?}",
230            operator
231        );
232    }
233
234    let exprs = split_conjunction_owned(filter);
235
236    // Treat 'is not distinct from' comparison as join key in equal joins
237    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
238    let mut accum_filters: Vec<Expr> = vec![];
239    for expr in exprs {
240        match expr {
241            Expr::BinaryExpr(BinaryExpr {
242                ref left,
243                ref op,
244                ref right,
245            }) if *op == operator => {
246                let join_key_pair =
247                    find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
248
249                if let Some((left_expr, right_expr)) = join_key_pair {
250                    let left_expr_type = left_expr.get_type(left_schema)?;
251                    let right_expr_type = right_expr.get_type(right_schema)?;
252
253                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
254                        accum_join_keys.push((left_expr, right_expr));
255                    } else {
256                        accum_filters.push(expr);
257                    }
258                } else {
259                    accum_filters.push(expr);
260                }
261            }
262            _ => accum_filters.push(expr),
263        }
264    }
265
266    let result_filter = accum_filters.into_iter().reduce(Expr::and);
267    Ok((accum_join_keys, result_filter))
268}
269
// Snapshot tests for `ExtractEquijoinPredicate`: each test builds a join plan
// with a filter and checks the optimized plan's rendering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
    };
    use std::sync::Arc;

    // Wraps `ExtractEquijoinPredicate` in an `Arc` and forwards it, the plan
    // and the expected inline snapshot to
    // `assert_optimized_plan_eq_display_indent_snapshot!`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    /// A filter consisting of a single column equality (`t1.a = t2.a`) is
    /// turned into a join key, leaving no filter behind.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// An equality between two non-column expressions, each referring to one
    /// side of the join, is also extracted as a join key.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// A filter with only non-equality comparisons (`>=`, `<`) yields no join
    /// keys; the whole filter stays in place.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join:  Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// When the join already has expression keys, the equality found in the
    /// filter is appended after them and the non-equality part remains as the
    /// filter.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// An `OR` conjunct is kept as the filter even though it contains an
    /// equality, while the two ANDed equalities become join keys.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// With nested joins over three tables, both the inner and the outer join
    /// get their equality extracted; the remaining comparisons stay as
    /// filters.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// An equality whose two operands both come from the right input of the
    /// outer join (`t2.c = t3.c`) is not a join key for that join and stays
    /// in its filter.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    /// An aliased equality filter is still recognized; the extracted join key
    /// is rendered without the alias.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int32(2), UInt32) as t1.a + 1 = t2.a + 2
        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}