Skip to main content

datafusion_optimizer/
eliminate_outer_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins
19use crate::{OptimizerConfig, OptimizerRule};
20use datafusion_common::{Column, DFSchema, Result};
21use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan};
22use datafusion_expr::{Expr, Filter, Operator};
23
24use crate::optimizer::ApplyOrder;
25use datafusion_common::tree_node::Transformed;
26use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
27use std::sync::Arc;
28
/// Attempt to replace outer joins with inner joins.
///
/// Outer joins are typically more expensive to compute at runtime
/// than inner joins and prevent various forms of predicate pushdown
/// and other optimizations, so removing them if possible is beneficial.
///
/// Inner joins filter out rows that do not match. Outer joins pass rows
/// that do not match, padded with nulls. If a filter in the query
/// would remove any such null-padded rows after the join, the rows
/// introduced by the outer join are filtered out anyway.
///
/// For example, in `select ... from a left join b on ... where b.xx = 100;`,
/// for rows where `b.xx` is null (as it would be after an outer join),
/// the `b.xx = 100` predicate filters them out and there is no
/// need to produce null rows for output.
///
/// Generally, an outer join can be rewritten to an inner join if the
/// filters from the WHERE clause do not evaluate to true whenever any
/// of their inputs is null, and the columns those filters reference
/// come from the nullable side of the outer join.
#[derive(Default, Debug)]
pub struct EliminateOuterJoin;
53
54impl EliminateOuterJoin {
55    #[expect(missing_docs)]
56    pub fn new() -> Self {
57        Self {}
58    }
59}
60
61/// Attempt to eliminate outer joins.
62impl OptimizerRule for EliminateOuterJoin {
63    fn name(&self) -> &str {
64        "eliminate_outer_join"
65    }
66
67    fn apply_order(&self) -> Option<ApplyOrder> {
68        Some(ApplyOrder::TopDown)
69    }
70
71    fn supports_rewrite(&self) -> bool {
72        true
73    }
74
75    fn rewrite(
76        &self,
77        plan: LogicalPlan,
78        _config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        match plan {
81            LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) {
82                LogicalPlan::Join(join) => {
83                    let mut non_nullable_cols: Vec<Column> = vec![];
84
85                    extract_non_nullable_columns(
86                        &filter.predicate,
87                        &mut non_nullable_cols,
88                        join.left.schema(),
89                        join.right.schema(),
90                        true,
91                    );
92
93                    let new_join_type = if join.join_type.is_outer() {
94                        let mut left_non_nullable = false;
95                        let mut right_non_nullable = false;
96                        for col in non_nullable_cols.iter() {
97                            if join.left.schema().has_column(col) {
98                                left_non_nullable = true;
99                            }
100                            if join.right.schema().has_column(col) {
101                                right_non_nullable = true;
102                            }
103                        }
104                        eliminate_outer(
105                            join.join_type,
106                            left_non_nullable,
107                            right_non_nullable,
108                        )
109                    } else {
110                        join.join_type
111                    };
112
113                    let new_join = Arc::new(LogicalPlan::Join(Join {
114                        left: join.left,
115                        right: join.right,
116                        join_type: new_join_type,
117                        join_constraint: join.join_constraint,
118                        on: join.on.clone(),
119                        filter: join.filter.clone(),
120                        schema: Arc::clone(&join.schema),
121                        null_equality: join.null_equality,
122                        null_aware: join.null_aware,
123                    }));
124                    Filter::try_new(filter.predicate, new_join)
125                        .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
126                }
127                filter_input => {
128                    filter.input = Arc::new(filter_input);
129                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
130                }
131            },
132            _ => Ok(Transformed::no(plan)),
133        }
134    }
135}
136
137pub fn eliminate_outer(
138    join_type: JoinType,
139    left_non_nullable: bool,
140    right_non_nullable: bool,
141) -> JoinType {
142    let mut new_join_type = join_type;
143    match join_type {
144        JoinType::Left => {
145            if right_non_nullable {
146                new_join_type = JoinType::Inner;
147            }
148        }
149        JoinType::Right => {
150            if left_non_nullable {
151                new_join_type = JoinType::Inner;
152            }
153        }
154        JoinType::Full => {
155            if left_non_nullable && right_non_nullable {
156                new_join_type = JoinType::Inner;
157            } else if left_non_nullable {
158                new_join_type = JoinType::Left;
159            } else if right_non_nullable {
160                new_join_type = JoinType::Right;
161            }
162        }
163        _ => {}
164    }
165    new_join_type
166}
167
168/// Recursively traverses expr, if expr returns false when
169/// any inputs are null, treats columns of both sides as non_nullable columns.
170///
171/// For and/or expr, extracts from all sub exprs and merges the columns.
172/// For or expr, if one of sub exprs returns true, discards all columns from or expr.
173/// For IS NOT NULL/NOT expr, always returns false for NULL input.
174///     extracts columns from these exprs.
175/// For all other exprs, fall through
176fn extract_non_nullable_columns(
177    expr: &Expr,
178    non_nullable_cols: &mut Vec<Column>,
179    left_schema: &Arc<DFSchema>,
180    right_schema: &Arc<DFSchema>,
181    top_level: bool,
182) {
183    match expr {
184        Expr::Column(col) => {
185            non_nullable_cols.push(col.clone());
186        }
187        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
188            // If one of the inputs are null for these operators, the results should be false.
189            Operator::Eq
190            | Operator::NotEq
191            | Operator::Lt
192            | Operator::LtEq
193            | Operator::Gt
194            | Operator::GtEq => {
195                extract_non_nullable_columns(
196                    left,
197                    non_nullable_cols,
198                    left_schema,
199                    right_schema,
200                    false,
201                );
202                extract_non_nullable_columns(
203                    right,
204                    non_nullable_cols,
205                    left_schema,
206                    right_schema,
207                    false,
208                )
209            }
210            Operator::And | Operator::Or => {
211                // treat And as Or if does not from top level, such as
212                // not (c1 < 10 and c2 > 100)
213                if top_level && *op == Operator::And {
214                    extract_non_nullable_columns(
215                        left,
216                        non_nullable_cols,
217                        left_schema,
218                        right_schema,
219                        top_level,
220                    );
221                    extract_non_nullable_columns(
222                        right,
223                        non_nullable_cols,
224                        left_schema,
225                        right_schema,
226                        top_level,
227                    );
228                    return;
229                }
230
231                let mut left_non_nullable_cols: Vec<Column> = vec![];
232                let mut right_non_nullable_cols: Vec<Column> = vec![];
233
234                extract_non_nullable_columns(
235                    left,
236                    &mut left_non_nullable_cols,
237                    left_schema,
238                    right_schema,
239                    top_level,
240                );
241                extract_non_nullable_columns(
242                    right,
243                    &mut right_non_nullable_cols,
244                    left_schema,
245                    right_schema,
246                    top_level,
247                );
248
249                // for query: select *** from a left join b where b.c1 ... or b.c2 ...
250                // this can be eliminated to inner join.
251                // for query: select *** from a left join b where a.c1 ... or b.c2 ...
252                // this can not be eliminated.
253                // If columns of relation exist in both sub exprs, any columns of this relation
254                // can be added to non nullable columns.
255                if !left_non_nullable_cols.is_empty()
256                    && !right_non_nullable_cols.is_empty()
257                {
258                    for left_col in &left_non_nullable_cols {
259                        for right_col in &right_non_nullable_cols {
260                            if (left_schema.has_column(left_col)
261                                && left_schema.has_column(right_col))
262                                || (right_schema.has_column(left_col)
263                                    && right_schema.has_column(right_col))
264                            {
265                                non_nullable_cols.push(left_col.clone());
266                                break;
267                            }
268                        }
269                    }
270                }
271            }
272            _ => {}
273        },
274        Expr::Not(arg) => extract_non_nullable_columns(
275            arg,
276            non_nullable_cols,
277            left_schema,
278            right_schema,
279            false,
280        ),
281        Expr::IsNotNull(arg) => {
282            if !top_level {
283                return;
284            }
285            extract_non_nullable_columns(
286                arg,
287                non_nullable_cols,
288                left_schema,
289                right_schema,
290                false,
291            )
292        }
293        Expr::Cast(Cast { expr, data_type: _ })
294        | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns(
295            expr,
296            non_nullable_cols,
297            left_schema,
298            right_schema,
299            false,
300        ),
301        _ => {}
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        Operator::{And, Or},
        binary_expr, cast, col, lit,
        logical_plan::builder::LogicalPlanBuilder,
        try_cast,
    };

    // Optimizes `$plan` with `EliminateOuterJoin` as the only rule, using a
    // context limited to one pass, and compares the result against the
    // inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateOuterJoin::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Builds `t1 <join_type> JOIN t2 ON t1.a = t2.a` with `predicate`
    /// applied as a filter on top of the join.
    fn filtered_join(join_type: JoinType, predicate: Expr) -> Result<LogicalPlan> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        LogicalPlanBuilder::from(t1)
            .join(
                t2,
                join_type,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(predicate)?
            .build()
    }

    #[test]
    fn eliminate_left_with_null() -> Result<()> {
        // `IS NULL` on the nullable side: the join must stay a left join.
        let plan = filtered_join(JoinType::Left, col("t2.b").is_null())?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NULL
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_left_with_not_null() -> Result<()> {
        // `IS NOT NULL` on the nullable side: becomes an inner join.
        let plan = filtered_join(JoinType::Left, col("t2.b").is_not_null())?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NOT NULL
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_right_with_or() -> Result<()> {
        // Both `OR` operands reference the nullable side: becomes an inner join.
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            Or,
            col("t1.c").lt(lit(20u32)),
        );
        let plan = filtered_join(JoinType::Right, predicate)?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_and() -> Result<()> {
        // One conjunct per side: the full join becomes an inner join.
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            And,
            col("t2.c").lt(lit(20u32)),
        );
        let plan = filtered_join(JoinType::Full, predicate)?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_type_cast() -> Result<()> {
        // Casts around the columns do not prevent the rewrite to an inner join.
        let predicate = binary_expr(
            cast(col("t1.b"), DataType::Int64).gt(lit(10u32)),
            And,
            try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
        );
        let plan = filtered_join(JoinType::Full, predicate)?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }
}
466}