datafusion_optimizer/
extract_equijoin_predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::DFSchema;
23use datafusion_common::Result;
24use datafusion_expr::utils::split_conjunction_owned;
25use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
26use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
// An equijoin predicate: a `(left, right)` pair of join-key expressions that are compared for equality.
28type EquijoinPredicate = (Expr, Expr);
29
/// Optimizer rule that splits conjunctive join predicates into equijoin
/// predicates and (other) filter predicates.
///
/// Join algorithms are often highly optimized for equality predicates such as
/// `x = y`, often called `equijoin` predicates, so it is important to locate
/// such predicates and treat them specially.
///
/// For example, `SELECT ... FROM A JOIN B ON (A.x = B.y AND B.z > 50)`
/// has one equijoin predicate (`A.x = B.y`) and one filter predicate
/// (`B.z > 50`).
///
/// See [`find_valid_equijoin_key_pair`] for more information on what
/// predicates are considered equijoins.
#[derive(Default, Debug)]
pub struct ExtractEquijoinPredicate;

impl ExtractEquijoinPredicate {
    /// Creates a new [`ExtractEquijoinPredicate`] rule.
    ///
    /// The rule carries no state, so this is equivalent to
    /// `ExtractEquijoinPredicate::default()`.
    pub fn new() -> Self {
        Self
    }
}
50
51impl OptimizerRule for ExtractEquijoinPredicate {
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn name(&self) -> &str {
57        "extract_equijoin_predicate"
58    }
59
60    fn apply_order(&self) -> Option<ApplyOrder> {
61        Some(ApplyOrder::BottomUp)
62    }
63
64    fn rewrite(
65        &self,
66        plan: LogicalPlan,
67        _config: &dyn OptimizerConfig,
68    ) -> Result<Transformed<LogicalPlan>> {
69        match plan {
70            LogicalPlan::Join(Join {
71                left,
72                right,
73                mut on,
74                filter: Some(expr),
75                join_type,
76                join_constraint,
77                schema,
78                null_equals_null,
79            }) => {
80                let left_schema = left.schema();
81                let right_schema = right.schema();
82                let (equijoin_predicates, non_equijoin_expr) =
83                    split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?;
84
85                if !equijoin_predicates.is_empty() {
86                    on.extend(equijoin_predicates);
87                    Ok(Transformed::yes(LogicalPlan::Join(Join {
88                        left,
89                        right,
90                        on,
91                        filter: non_equijoin_expr,
92                        join_type,
93                        join_constraint,
94                        schema,
95                        null_equals_null,
96                    })))
97                } else {
98                    Ok(Transformed::no(LogicalPlan::Join(Join {
99                        left,
100                        right,
101                        on,
102                        filter: non_equijoin_expr,
103                        join_type,
104                        join_constraint,
105                        schema,
106                        null_equals_null,
107                    })))
108                }
109            }
110            _ => Ok(Transformed::no(plan)),
111        }
112    }
113}
114
115fn split_eq_and_noneq_join_predicate(
116    filter: Expr,
117    left_schema: &DFSchema,
118    right_schema: &DFSchema,
119) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
120    let exprs = split_conjunction_owned(filter);
121
122    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
123    let mut accum_filters: Vec<Expr> = vec![];
124    for expr in exprs {
125        match expr {
126            Expr::BinaryExpr(BinaryExpr {
127                ref left,
128                op: Operator::Eq,
129                ref right,
130            }) => {
131                let join_key_pair =
132                    find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?;
133
134                if let Some((left_expr, right_expr)) = join_key_pair {
135                    let left_expr_type = left_expr.get_type(left_schema)?;
136                    let right_expr_type = right_expr.get_type(right_schema)?;
137
138                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
139                        accum_join_keys.push((left_expr, right_expr));
140                    } else {
141                        accum_filters.push(expr);
142                    }
143                } else {
144                    accum_filters.push(expr);
145                }
146            }
147            _ => accum_filters.push(expr),
148        }
149    }
150
151    let result_filter = accum_filters.into_iter().reduce(Expr::and);
152    Ok((accum_join_keys, result_filter))
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
    };
    use std::sync::Arc;

    /// Checks `plan` against `expected` by handing both, together with the
    /// [`ExtractEquijoinPredicate`] rule, to
    /// `assert_optimized_plan_eq_display_indent`.
    fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
        let rule = Arc::new(ExtractEquijoinPredicate::new());
        assert_optimized_plan_eq_display_indent(rule, plan, expected);
        Ok(())
    }

    /// A single column-to-column equality ends up as the join key.
    #[test]
    fn join_with_only_column_equi_predicate() -> Result<()> {
        let left = test_table_scan_with_name("t1")?;
        let right = test_table_scan_with_name("t2")?;

        let predicate = col("t1.a").eq(col("t2.a"));
        let plan = LogicalPlanBuilder::from(left)
            .join_on(right, JoinType::Left, Some(predicate))?
            .build()?;
        let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An equality between two arithmetic expressions ends up as the join key.
    #[test]
    fn join_with_only_equi_expr_predicate() -> Result<()> {
        let left = test_table_scan_with_name("t1")?;
        let right = test_table_scan_with_name("t2")?;

        let predicate = (col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32));
        let plan = LogicalPlanBuilder::from(left)
            .join_on(right, JoinType::Left, Some(predicate))?
            .build()?;
        let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// With no equality conjunct at all, everything stays in the filter.
    #[test]
    fn join_with_only_none_equi_predicate() -> Result<()> {
        let left = test_table_scan_with_name("t1")?;
        let right = test_table_scan_with_name("t2")?;

        let range_cmp = (col("t1.a") + lit(10i64)).gt_eq(col("t2.a") * lit(2u32));
        let predicate = range_cmp.and(col("t1.b").lt(lit(100i32)));
        let plan = LogicalPlanBuilder::from(left)
            .join_on(right, JoinType::Left, Some(predicate))?
            .build()?;
        let expected = "Left Join:  Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// Keys extracted from the filter are appended after the keys the join
    /// already had.
    #[test]
    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
        let left = test_table_scan_with_name("t1")?;
        let right = test_table_scan_with_name("t2")?;

        let existing_keys = (
            vec![col("t1.a") + lit(11u32)],
            vec![col("t2.a") * lit(2u32)],
        );
        let equality = (col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32));
        let predicate = equality.and(col("t1.b").lt(lit(100i32)));
        let plan = LogicalPlanBuilder::from(left)
            .join_with_expr_keys(right, JoinType::Left, existing_keys, Some(predicate))?
            .build()?;
        let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An equality nested inside an `OR` stays in the filter, while the
    /// top-level `AND`-ed equalities become join keys.
    #[test]
    fn join_with_and_or_filter() -> Result<()> {
        let left = test_table_scan_with_name("t1")?;
        let right = test_table_scan_with_name("t2")?;

        let disjunction = col("t1.c")
            .eq(col("t2.c"))
            .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c")));
        let equalities = col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b")));
        let predicate = disjunction.and(equalities);
        let plan = LogicalPlanBuilder::from(left)
            .join_on(right, JoinType::Left, Some(predicate))?
            .build()?;
        let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// Both the inner and the outer join of a three-table plan are rewritten.
    #[test]
    fn join_with_multiple_table() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let inner_predicate = col("t2.a")
            .eq(col("t3.a"))
            .and((col("t2.a") + col("t3.b")).gt(lit(100u32)));
        let inner_join = LogicalPlanBuilder::from(t2)
            .join_on(t3, JoinType::Left, Some(inner_predicate))?
            .build()?;

        let outer_predicate = col("t1.a")
            .eq(col("t2.a"))
            .and((col("t1.c") + col("t2.c") + col("t3.c")).lt(lit(100u32)));
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(inner_join, JoinType::Left, Some(outer_predicate))?
            .build()?;
        let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
            \n  Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
            \n    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
            \n    TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An equality whose two sides both come from the right-hand input of
    /// the outer join (`t2.c = t3.c`) remains a filter on that join.
    #[test]
    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        let inner_predicate = col("t2.a")
            .eq(col("t3.a"))
            .and((col("t2.a") + col("t3.b")).gt(lit(100u32)));
        let inner_join = LogicalPlanBuilder::from(t2)
            .join_on(t3, JoinType::Left, Some(inner_predicate))?
            .build()?;

        let outer_predicate = col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")));
        let plan = LogicalPlanBuilder::from(t1)
            .join_on(inner_join, JoinType::Left, Some(outer_predicate))?
            .build()?;
        let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n  Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
        \n    TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }

    /// An aliased equality is still extracted; the expected plan shows the
    /// key pair without the alias.
    #[test]
    fn join_with_alias_filter() -> Result<()> {
        let left = test_table_scan_with_name("t1")?;
        let right = test_table_scan_with_name("t2")?;

        let left_schema = Arc::clone(left.schema());
        let right_schema = Arc::clone(right.schema());

        // filter: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32),
        // aliased as "t1.a + 1 = t2.a + 2"
        let lhs = col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &left_schema)?;
        let rhs = col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &right_schema)?;
        let predicate = Expr::eq(lhs, rhs).alias("t1.a + 1 = t2.a + 2");
        let plan = LogicalPlanBuilder::from(left)
            .join_on(right, JoinType::Left, Some(predicate))?
            .build()?;
        let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
        \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
        \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

        assert_plan_eq(plan, expected)
    }
}
378}