datafusion_optimizer/
eliminate_outer_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
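//! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins (and
//! `FULL` joins to `LEFT`/`RIGHT` joins) when a filter above the join rejects the
//! null-padded rows the outer join would produce.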
19use crate::{OptimizerConfig, OptimizerRule};
20use datafusion_common::{Column, DFSchema, Result};
21use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan};
22use datafusion_expr::{Expr, Filter, Operator};
23
24use crate::optimizer::ApplyOrder;
25use datafusion_common::tree_node::Transformed;
26use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
27use std::sync::Arc;
28
/// Attempt to replace outer joins with inner joins.
///
/// Outer joins are typically more expensive to compute at runtime
/// than inner joins and prevent various forms of predicate pushdown
/// and other optimizations, so removing them if possible is beneficial.
///
/// Inner joins filter out rows that do not match. Outer joins pass rows
/// that do not match, padded with nulls. If there is a filter in the
/// query that would reject any such null rows after the join, the rows
/// introduced by the outer join are filtered out anyway.
///
/// For example, in `select ... from a left join b on ... where b.xx = 100;`
/// the rows where `b.xx` is null (as it would be for rows padded by the
/// outer join) are rejected by the `b.xx = 100` predicate, so there is no
/// need to produce null-padded rows in the first place.
///
/// Generally, an outer join can be rewritten to an inner join if the
/// filters from the WHERE clause cannot evaluate to true while any of their
/// inputs are null, and the columns of those predicates come from the
/// nullable side of the outer join. A `FULL` join whose filter only rejects
/// nulls coming from one side is downgraded to a `LEFT` or `RIGHT` join.
#[derive(Default, Debug)]
pub struct EliminateOuterJoin;

impl EliminateOuterJoin {
    /// Creates a new [`EliminateOuterJoin`] optimizer rule.
    pub fn new() -> Self {
        Self
    }
}
60
61/// Attempt to eliminate outer joins.
62impl OptimizerRule for EliminateOuterJoin {
63    fn name(&self) -> &str {
64        "eliminate_outer_join"
65    }
66
67    fn apply_order(&self) -> Option<ApplyOrder> {
68        Some(ApplyOrder::TopDown)
69    }
70
71    fn supports_rewrite(&self) -> bool {
72        true
73    }
74
75    fn rewrite(
76        &self,
77        plan: LogicalPlan,
78        _config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        match plan {
81            LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) {
82                LogicalPlan::Join(join) => {
83                    let mut non_nullable_cols: Vec<Column> = vec![];
84
85                    extract_non_nullable_columns(
86                        &filter.predicate,
87                        &mut non_nullable_cols,
88                        join.left.schema(),
89                        join.right.schema(),
90                        true,
91                    );
92
93                    let new_join_type = if join.join_type.is_outer() {
94                        let mut left_non_nullable = false;
95                        let mut right_non_nullable = false;
96                        for col in non_nullable_cols.iter() {
97                            if join.left.schema().has_column(col) {
98                                left_non_nullable = true;
99                            }
100                            if join.right.schema().has_column(col) {
101                                right_non_nullable = true;
102                            }
103                        }
104                        eliminate_outer(
105                            join.join_type,
106                            left_non_nullable,
107                            right_non_nullable,
108                        )
109                    } else {
110                        join.join_type
111                    };
112
113                    let new_join = Arc::new(LogicalPlan::Join(Join {
114                        left: join.left,
115                        right: join.right,
116                        join_type: new_join_type,
117                        join_constraint: join.join_constraint,
118                        on: join.on.clone(),
119                        filter: join.filter.clone(),
120                        schema: Arc::clone(&join.schema),
121                        null_equality: join.null_equality,
122                    }));
123                    Filter::try_new(filter.predicate, new_join)
124                        .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
125                }
126                filter_input => {
127                    filter.input = Arc::new(filter_input);
128                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
129                }
130            },
131            _ => Ok(Transformed::no(plan)),
132        }
133    }
134}
135
136pub fn eliminate_outer(
137    join_type: JoinType,
138    left_non_nullable: bool,
139    right_non_nullable: bool,
140) -> JoinType {
141    let mut new_join_type = join_type;
142    match join_type {
143        JoinType::Left => {
144            if right_non_nullable {
145                new_join_type = JoinType::Inner;
146            }
147        }
148        JoinType::Right => {
149            if left_non_nullable {
150                new_join_type = JoinType::Inner;
151            }
152        }
153        JoinType::Full => {
154            if left_non_nullable && right_non_nullable {
155                new_join_type = JoinType::Inner;
156            } else if left_non_nullable {
157                new_join_type = JoinType::Left;
158            } else if right_non_nullable {
159                new_join_type = JoinType::Right;
160            }
161        }
162        _ => {}
163    }
164    new_join_type
165}
166
167/// Recursively traverses expr, if expr returns false when
168/// any inputs are null, treats columns of both sides as non_nullable columns.
169///
170/// For and/or expr, extracts from all sub exprs and merges the columns.
171/// For or expr, if one of sub exprs returns true, discards all columns from or expr.
172/// For IS NOT NULL/NOT expr, always returns false for NULL input.
173///     extracts columns from these exprs.
174/// For all other exprs, fall through
175fn extract_non_nullable_columns(
176    expr: &Expr,
177    non_nullable_cols: &mut Vec<Column>,
178    left_schema: &Arc<DFSchema>,
179    right_schema: &Arc<DFSchema>,
180    top_level: bool,
181) {
182    match expr {
183        Expr::Column(col) => {
184            non_nullable_cols.push(col.clone());
185        }
186        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
187            // If one of the inputs are null for these operators, the results should be false.
188            Operator::Eq
189            | Operator::NotEq
190            | Operator::Lt
191            | Operator::LtEq
192            | Operator::Gt
193            | Operator::GtEq => {
194                extract_non_nullable_columns(
195                    left,
196                    non_nullable_cols,
197                    left_schema,
198                    right_schema,
199                    false,
200                );
201                extract_non_nullable_columns(
202                    right,
203                    non_nullable_cols,
204                    left_schema,
205                    right_schema,
206                    false,
207                )
208            }
209            Operator::And | Operator::Or => {
210                // treat And as Or if does not from top level, such as
211                // not (c1 < 10 and c2 > 100)
212                if top_level && *op == Operator::And {
213                    extract_non_nullable_columns(
214                        left,
215                        non_nullable_cols,
216                        left_schema,
217                        right_schema,
218                        top_level,
219                    );
220                    extract_non_nullable_columns(
221                        right,
222                        non_nullable_cols,
223                        left_schema,
224                        right_schema,
225                        top_level,
226                    );
227                    return;
228                }
229
230                let mut left_non_nullable_cols: Vec<Column> = vec![];
231                let mut right_non_nullable_cols: Vec<Column> = vec![];
232
233                extract_non_nullable_columns(
234                    left,
235                    &mut left_non_nullable_cols,
236                    left_schema,
237                    right_schema,
238                    top_level,
239                );
240                extract_non_nullable_columns(
241                    right,
242                    &mut right_non_nullable_cols,
243                    left_schema,
244                    right_schema,
245                    top_level,
246                );
247
248                // for query: select *** from a left join b where b.c1 ... or b.c2 ...
249                // this can be eliminated to inner join.
250                // for query: select *** from a left join b where a.c1 ... or b.c2 ...
251                // this can not be eliminated.
252                // If columns of relation exist in both sub exprs, any columns of this relation
253                // can be added to non nullable columns.
254                if !left_non_nullable_cols.is_empty()
255                    && !right_non_nullable_cols.is_empty()
256                {
257                    for left_col in &left_non_nullable_cols {
258                        for right_col in &right_non_nullable_cols {
259                            if (left_schema.has_column(left_col)
260                                && left_schema.has_column(right_col))
261                                || (right_schema.has_column(left_col)
262                                    && right_schema.has_column(right_col))
263                            {
264                                non_nullable_cols.push(left_col.clone());
265                                break;
266                            }
267                        }
268                    }
269                }
270            }
271            _ => {}
272        },
273        Expr::Not(arg) => extract_non_nullable_columns(
274            arg,
275            non_nullable_cols,
276            left_schema,
277            right_schema,
278            false,
279        ),
280        Expr::IsNotNull(arg) => {
281            if !top_level {
282                return;
283            }
284            extract_non_nullable_columns(
285                arg,
286                non_nullable_cols,
287                left_schema,
288                right_schema,
289                false,
290            )
291        }
292        Expr::Cast(Cast { expr, data_type: _ })
293        | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns(
294            expr,
295            non_nullable_cols,
296            left_schema,
297            right_schema,
298            false,
299        ),
300        _ => {}
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use crate::OptimizerContext;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        binary_expr, cast, col, lit,
        logical_plan::builder::LogicalPlanBuilder,
        try_cast,
        Operator::{And, Or},
    };

    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateOuterJoin::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn eliminate_left_with_null() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could not eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t2.b").is_null())?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NULL
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_left_with_not_null() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t2.b").is_not_null())?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NOT NULL
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_left_with_or_across_relations() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // `t1.b > 10` can be true for null-padded t2 rows: could not eliminate
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                col("t1.b").gt(lit(10u32)),
                Or,
                col("t2.c").lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) OR t2.c < UInt32(20)
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_right_with_or() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Right,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                col("t1.b").gt(lit(10u32)),
                Or,
                col("t1.c").lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_and() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                col("t1.b").gt(lit(10u32)),
                And,
                col("t2.c").lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_to_left() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // only nulls from t1 are rejected: downgrade to left join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t1.b").gt(lit(10u32)))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10)
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_to_right() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // only nulls from t2 are rejected: downgrade to right join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t2.c").lt(lit(20u32)))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.c < UInt32(20)
          Right Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_type_cast() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                cast(col("t1.b"), DataType::Int64).gt(lit(10u32)),
                And,
                try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }
}