1use crate::{OptimizerConfig, OptimizerRule};
20use datafusion_common::{Column, DFSchema, Result};
21use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan};
22use datafusion_expr::{Expr, Filter, Operator};
23
24use crate::optimizer::ApplyOrder;
25use datafusion_common::tree_node::Transformed;
26use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
27use std::sync::Arc;
28
/// Optimizer rule that narrows an outer join sitting directly below a `Filter`.
///
/// When the filter predicate can only be satisfied by rows whose columns from
/// one join side are non-null, the null-padded rows produced for that side can
/// never survive the filter, so the join type is tightened (for example
/// `LEFT` becomes `INNER`, or `FULL` becomes `LEFT`/`RIGHT`/`INNER`).
#[derive(Debug, Default)]
pub struct EliminateOuterJoin;
53
54impl EliminateOuterJoin {
55 #[allow(missing_docs)]
56 pub fn new() -> Self {
57 Self {}
58 }
59}
60
61impl OptimizerRule for EliminateOuterJoin {
63 fn name(&self) -> &str {
64 "eliminate_outer_join"
65 }
66
67 fn apply_order(&self) -> Option<ApplyOrder> {
68 Some(ApplyOrder::TopDown)
69 }
70
71 fn supports_rewrite(&self) -> bool {
72 true
73 }
74
75 fn rewrite(
76 &self,
77 plan: LogicalPlan,
78 _config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 match plan {
81 LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) {
82 LogicalPlan::Join(join) => {
83 let mut non_nullable_cols: Vec<Column> = vec![];
84
85 extract_non_nullable_columns(
86 &filter.predicate,
87 &mut non_nullable_cols,
88 join.left.schema(),
89 join.right.schema(),
90 true,
91 );
92
93 let new_join_type = if join.join_type.is_outer() {
94 let mut left_non_nullable = false;
95 let mut right_non_nullable = false;
96 for col in non_nullable_cols.iter() {
97 if join.left.schema().has_column(col) {
98 left_non_nullable = true;
99 }
100 if join.right.schema().has_column(col) {
101 right_non_nullable = true;
102 }
103 }
104 eliminate_outer(
105 join.join_type,
106 left_non_nullable,
107 right_non_nullable,
108 )
109 } else {
110 join.join_type
111 };
112
113 let new_join = Arc::new(LogicalPlan::Join(Join {
114 left: join.left,
115 right: join.right,
116 join_type: new_join_type,
117 join_constraint: join.join_constraint,
118 on: join.on.clone(),
119 filter: join.filter.clone(),
120 schema: Arc::clone(&join.schema),
121 null_equals_null: join.null_equals_null,
122 }));
123 Filter::try_new(filter.predicate, new_join)
124 .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
125 }
126 filter_input => {
127 filter.input = Arc::new(filter_input);
128 Ok(Transformed::no(LogicalPlan::Filter(filter)))
129 }
130 },
131 _ => Ok(Transformed::no(plan)),
132 }
133 }
134}
135
136pub fn eliminate_outer(
137 join_type: JoinType,
138 left_non_nullable: bool,
139 right_non_nullable: bool,
140) -> JoinType {
141 let mut new_join_type = join_type;
142 match join_type {
143 JoinType::Left => {
144 if right_non_nullable {
145 new_join_type = JoinType::Inner;
146 }
147 }
148 JoinType::Right => {
149 if left_non_nullable {
150 new_join_type = JoinType::Inner;
151 }
152 }
153 JoinType::Full => {
154 if left_non_nullable && right_non_nullable {
155 new_join_type = JoinType::Inner;
156 } else if left_non_nullable {
157 new_join_type = JoinType::Left;
158 } else if right_non_nullable {
159 new_join_type = JoinType::Right;
160 }
161 }
162 _ => {}
163 }
164 new_join_type
165}
166
167fn extract_non_nullable_columns(
176 expr: &Expr,
177 non_nullable_cols: &mut Vec<Column>,
178 left_schema: &Arc<DFSchema>,
179 right_schema: &Arc<DFSchema>,
180 top_level: bool,
181) {
182 match expr {
183 Expr::Column(col) => {
184 non_nullable_cols.push(col.clone());
185 }
186 Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
187 Operator::Eq
189 | Operator::NotEq
190 | Operator::Lt
191 | Operator::LtEq
192 | Operator::Gt
193 | Operator::GtEq => {
194 extract_non_nullable_columns(
195 left,
196 non_nullable_cols,
197 left_schema,
198 right_schema,
199 false,
200 );
201 extract_non_nullable_columns(
202 right,
203 non_nullable_cols,
204 left_schema,
205 right_schema,
206 false,
207 )
208 }
209 Operator::And | Operator::Or => {
210 if top_level && *op == Operator::And {
213 extract_non_nullable_columns(
214 left,
215 non_nullable_cols,
216 left_schema,
217 right_schema,
218 top_level,
219 );
220 extract_non_nullable_columns(
221 right,
222 non_nullable_cols,
223 left_schema,
224 right_schema,
225 top_level,
226 );
227 return;
228 }
229
230 let mut left_non_nullable_cols: Vec<Column> = vec![];
231 let mut right_non_nullable_cols: Vec<Column> = vec![];
232
233 extract_non_nullable_columns(
234 left,
235 &mut left_non_nullable_cols,
236 left_schema,
237 right_schema,
238 top_level,
239 );
240 extract_non_nullable_columns(
241 right,
242 &mut right_non_nullable_cols,
243 left_schema,
244 right_schema,
245 top_level,
246 );
247
248 if !left_non_nullable_cols.is_empty()
255 && !right_non_nullable_cols.is_empty()
256 {
257 for left_col in &left_non_nullable_cols {
258 for right_col in &right_non_nullable_cols {
259 if (left_schema.has_column(left_col)
260 && left_schema.has_column(right_col))
261 || (right_schema.has_column(left_col)
262 && right_schema.has_column(right_col))
263 {
264 non_nullable_cols.push(left_col.clone());
265 break;
266 }
267 }
268 }
269 }
270 }
271 _ => {}
272 },
273 Expr::Not(arg) => extract_non_nullable_columns(
274 arg,
275 non_nullable_cols,
276 left_schema,
277 right_schema,
278 false,
279 ),
280 Expr::IsNotNull(arg) => {
281 if !top_level {
282 return;
283 }
284 extract_non_nullable_columns(
285 arg,
286 non_nullable_cols,
287 left_schema,
288 right_schema,
289 false,
290 )
291 }
292 Expr::Cast(Cast { expr, data_type: _ })
293 | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns(
294 expr,
295 non_nullable_cols,
296 left_schema,
297 right_schema,
298 false,
299 ),
300 _ => {}
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        binary_expr, cast, col, lit,
        logical_plan::builder::LogicalPlanBuilder,
        try_cast,
        Operator::{And, Or},
    };

    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected)
    }

    /// Builds `Filter(predicate)` over `t1 <join_type> JOIN t2 ON t1.a = t2.a`.
    fn filter_over_join(join_type: JoinType, predicate: Expr) -> Result<LogicalPlan> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        LogicalPlanBuilder::from(t1)
            .join(
                t2,
                join_type,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(predicate)?
            .build()
    }

    #[test]
    fn eliminate_left_with_null() -> Result<()> {
        // `IS NULL` accepts null-padded rows, so the left join must stay.
        let plan = filter_over_join(JoinType::Left, col("t2.b").is_null())?;
        let expected = "\
        Filter: t2.b IS NULL\
        \n  Left Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_left_with_not_null() -> Result<()> {
        let plan = filter_over_join(JoinType::Left, col("t2.b").is_not_null())?;
        let expected = "\
        Filter: t2.b IS NOT NULL\
        \n  Inner Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_right_with_or() -> Result<()> {
        // Both OR operands require the left side to be non-null.
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            Or,
            col("t1.c").lt(lit(20u32)),
        );
        let plan = filter_over_join(JoinType::Right, predicate)?;
        let expected = "\
        Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)\
        \n  Inner Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn keep_left_with_or_across_sides() -> Result<()> {
        // The OR operands constrain different join sides, so neither side is
        // guaranteed non-null and the left join must stay.
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            Or,
            col("t2.c").lt(lit(20u32)),
        );
        let plan = filter_over_join(JoinType::Left, predicate)?;
        let expected = "\
        Filter: t1.b > UInt32(10) OR t2.c < UInt32(20)\
        \n  Left Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_full_with_and() -> Result<()> {
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            And,
            col("t2.c").lt(lit(20u32)),
        );
        let plan = filter_over_join(JoinType::Full, predicate)?;
        let expected = "\
        Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)\
        \n  Inner Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_full_with_left_not_null() -> Result<()> {
        // Only the left side is constrained: FULL keeps its left padding.
        let plan = filter_over_join(JoinType::Full, col("t1.b").gt(lit(10u32)))?;
        let expected = "\
        Filter: t1.b > UInt32(10)\
        \n  Left Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_full_with_right_not_null() -> Result<()> {
        // Only the right side is constrained: FULL keeps its right padding.
        let plan = filter_over_join(JoinType::Full, col("t2.c").lt(lit(20u32)))?;
        let expected = "\
        Filter: t2.c < UInt32(20)\
        \n  Right Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_full_with_type_cast() -> Result<()> {
        let predicate = binary_expr(
            cast(col("t1.b"), DataType::Int64).gt(lit(10u32)),
            And,
            try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
        );
        let plan = filter_over_join(JoinType::Full, predicate)?;
        let expected = "\
        Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)\
        \n  Inner Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn keep_inner_join_unchanged() -> Result<()> {
        // Non-outer joins are never rewritten.
        let plan = filter_over_join(JoinType::Inner, col("t2.b").is_null())?;
        let expected = "\
        Filter: t2.b IS NULL\
        \n  Inner Join: t1.a = t2.a\
        \n    TableScan: t1\
        \n    TableScan: t2";
        assert_optimized_plan_equal(plan, expected)
    }
}