datafusion_optimizer/
extract_equijoin_predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates
//! in a join's filter and, where possible, moves them into the join's `on` keys.
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::DFSchema;
23use datafusion_common::Result;
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
27// equijoin predicate
28type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer that splits conjunctive join predicates into equijoin
/// predicates and (other) filter predicates.
///
/// Join algorithms are often highly optimized for equality predicates such as `x = y`,
/// often called `equijoin` predicates, so it is important to locate such predicates
/// and treat them specially.
///
/// For example, `SELECT ... FROM A JOIN B ON (A.x = B.y AND B.z > 50)`
/// has one equijoin predicate (`A.x = B.y`) and one filter predicate (`B.z > 50`).
/// See [`find_valid_equijoin_key_pair`] for more information on what predicates
/// are considered equijoins. An equality is only moved into the join keys when
/// the types of both key expressions are hashable (see [`can_hash`]); otherwise
/// it stays in the join filter.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates a new [`ExtractEquijoinPredicate`] rule.
    ///
    /// The rule carries no state, so this is equivalent to
    /// [`ExtractEquijoinPredicate::default`].
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn name(&self) -> &str {
57        "extract_equijoin_predicate"
58    }
59
60    fn apply_order(&self) -> Option<ApplyOrder> {
61        Some(ApplyOrder::BottomUp)
62    }
63
64    fn rewrite(
65        &self,
66        plan: LogicalPlan,
67        _config: &dyn OptimizerConfig,
68    ) -> Result<Transformed<LogicalPlan>> {
69        match plan {
70            LogicalPlan::Join(Join {
71                left,
72                right,
73                mut on,
74                filter: Some(expr),
75                join_type,
76                join_constraint,
77                schema,
78                null_equality,
79            }) => {
80                let left_schema = left.schema();
81                let right_schema = right.schema();
82                let (equijoin_predicates, non_equijoin_expr) =
83                    split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85                if !equijoin_predicates.is_empty() {
86                    on.extend(equijoin_predicates);
87                    Ok(Transformed::yes(LogicalPlan::Join(Join {
88                        left,
89                        right,
90                        on,
91                        filter: non_equijoin_expr,
92                        join_type,
93                        join_constraint,
94                        schema,
95                        null_equality,
96                    })))
97                } else {
98                    Ok(Transformed::no(LogicalPlan::Join(Join {
99                        left,
100                        right,
101                        on,
102                        filter: non_equijoin_expr,
103                        join_type,
104                        join_constraint,
105                        schema,
106                        null_equality,
107                    })))
108                }
109            }
110            _ => Ok(Transformed::no(plan)),
111        }
112    }
113}
114
115fn split_eq_and_noneq_join_predicate(
116    filter: Expr,
117    left_schema: &DFSchema,
118    right_schema: &DFSchema,
119) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
120    let exprs = split_conjunction_owned(filter);
121
122    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
123    let mut accum_filters: Vec<Expr> = vec![];
124    for expr in exprs {
125        match expr {
126            Expr::BinaryExpr(BinaryExpr {
127                ref left,
128                op: Operator::Eq,
129                ref right,
130            }) => {
131                let join_key_pair =
132                    find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
133
134                if let Some((left_expr, right_expr)) = join_key_pair {
135                    let left_expr_type = left_expr.get_type(left_schema)?;
136                    let right_expr_type = right_expr.get_type(right_schema)?;
137
138                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
139                        accum_join_keys.push((left_expr, right_expr));
140                    } else {
141                        accum_filters.push(expr);
142                    }
143                } else {
144                    accum_filters.push(expr);
145                }
146            }
147            _ => accum_filters.push(expr),
148        }
149    }
150
151    let result_filter = accum_filters.into_iter().reduce(Expr::and);
152    Ok((accum_join_keys, result_filter))
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
    };
    use std::sync::Arc;

    /// Runs [`ExtractEquijoinPredicate`] on `$plan` and asserts the optimized
    /// plan matches the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
                Arc::new(ExtractEquijoinPredicate::new());
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_only_non_equi_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    (col("t1.a") + lit(10i64))
                        .gt_eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join:  Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Left,
                (
                    vec![col("t1.a") + lit(11u32)],
                    vec![col("t2.a") * lit(2u32)],
                ),
                Some(
                    (col("t1.a") + lit(10i64))
                        .eq(col("t2.a") * lit(2u32))
                        .and(col("t1.b").lt(lit(100i32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                t2,
                JoinType::Left,
                Some(
                    col("t1.c")
                        .eq(col("t2.c"))
                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")))
                        .and(
                            col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
                        ),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(
                    col("t1.a")
                        .eq(col("t2.a"))
                        .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32))),
                ),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let input = LogicalPlanBuilder::from(t2)
            .join_on(
                t3,
                JoinType::Left,
                Some(
                    col("t2.a")
                        .eq(col("t3.a"))
                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
                ),
            )?
            .build()?;
        // Both sides of `t2.c = t3.c` come from the right input (`input`), so it
        // is not a key for the outer join and is expected to remain in its filter.
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(
                input,
                JoinType::Left,
                Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
            )?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        let t1_schema = Arc::clone(t1.schema());
        let t2_schema = Arc::clone(t2.schema());

        // filter: (t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32))
        // aliased as "t1.a + 1 = t2.a + 2". The alias does not prevent extraction
        // and does not appear in the resulting join key.
        let filter = Expr::eq(
            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
        )
        .alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(t2, JoinType::Left, Some(filter))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}