1use crate::{OptimizerConfig, OptimizerRule};
20use datafusion_common::{Column, DFSchema, Result};
21use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan};
22use datafusion_expr::{Expr, Filter, Operator};
23
24use crate::optimizer::ApplyOrder;
25use datafusion_common::tree_node::Transformed;
26use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
27use std::sync::Arc;
28
/// Optimizer rule that tries to turn outer joins into cheaper join types.
///
/// The rule looks at a `Filter` placed directly on top of a `Join`. Columns
/// that the filter predicate needs to be non-null are collected, and if they
/// belong to a side of the join that the outer join would otherwise pad with
/// NULLs, the join type is downgraded (for example `Left` to `Inner`, or
/// `Full` to `Left`/`Right`/`Inner`). See [`eliminate_outer`] for the exact
/// mapping.
#[derive(Debug, Default)]
pub struct EliminateOuterJoin;

impl EliminateOuterJoin {
    /// Creates a new instance of the rule. The rule carries no state.
    pub fn new() -> Self {
        Self
    }
}
60
61impl OptimizerRule for EliminateOuterJoin {
63 fn name(&self) -> &str {
64 "eliminate_outer_join"
65 }
66
67 fn apply_order(&self) -> Option<ApplyOrder> {
68 Some(ApplyOrder::TopDown)
69 }
70
71 fn supports_rewrite(&self) -> bool {
72 true
73 }
74
75 fn rewrite(
76 &self,
77 plan: LogicalPlan,
78 _config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 match plan {
81 LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) {
82 LogicalPlan::Join(join) => {
83 let mut non_nullable_cols: Vec<Column> = vec![];
84
85 extract_non_nullable_columns(
86 &filter.predicate,
87 &mut non_nullable_cols,
88 join.left.schema(),
89 join.right.schema(),
90 true,
91 );
92
93 let new_join_type = if join.join_type.is_outer() {
94 let mut left_non_nullable = false;
95 let mut right_non_nullable = false;
96 for col in non_nullable_cols.iter() {
97 if join.left.schema().has_column(col) {
98 left_non_nullable = true;
99 }
100 if join.right.schema().has_column(col) {
101 right_non_nullable = true;
102 }
103 }
104 eliminate_outer(
105 join.join_type,
106 left_non_nullable,
107 right_non_nullable,
108 )
109 } else {
110 join.join_type
111 };
112
113 let new_join = Arc::new(LogicalPlan::Join(Join {
114 left: join.left,
115 right: join.right,
116 join_type: new_join_type,
117 join_constraint: join.join_constraint,
118 on: join.on.clone(),
119 filter: join.filter.clone(),
120 schema: Arc::clone(&join.schema),
121 null_equality: join.null_equality,
122 }));
123 Filter::try_new(filter.predicate, new_join)
124 .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
125 }
126 filter_input => {
127 filter.input = Arc::new(filter_input);
128 Ok(Transformed::no(LogicalPlan::Filter(filter)))
129 }
130 },
131 _ => Ok(Transformed::no(plan)),
132 }
133 }
134}
135
136pub fn eliminate_outer(
137 join_type: JoinType,
138 left_non_nullable: bool,
139 right_non_nullable: bool,
140) -> JoinType {
141 let mut new_join_type = join_type;
142 match join_type {
143 JoinType::Left => {
144 if right_non_nullable {
145 new_join_type = JoinType::Inner;
146 }
147 }
148 JoinType::Right => {
149 if left_non_nullable {
150 new_join_type = JoinType::Inner;
151 }
152 }
153 JoinType::Full => {
154 if left_non_nullable && right_non_nullable {
155 new_join_type = JoinType::Inner;
156 } else if left_non_nullable {
157 new_join_type = JoinType::Left;
158 } else if right_non_nullable {
159 new_join_type = JoinType::Right;
160 }
161 }
162 _ => {}
163 }
164 new_join_type
165}
166
167fn extract_non_nullable_columns(
176 expr: &Expr,
177 non_nullable_cols: &mut Vec<Column>,
178 left_schema: &Arc<DFSchema>,
179 right_schema: &Arc<DFSchema>,
180 top_level: bool,
181) {
182 match expr {
183 Expr::Column(col) => {
184 non_nullable_cols.push(col.clone());
185 }
186 Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
187 Operator::Eq
189 | Operator::NotEq
190 | Operator::Lt
191 | Operator::LtEq
192 | Operator::Gt
193 | Operator::GtEq => {
194 extract_non_nullable_columns(
195 left,
196 non_nullable_cols,
197 left_schema,
198 right_schema,
199 false,
200 );
201 extract_non_nullable_columns(
202 right,
203 non_nullable_cols,
204 left_schema,
205 right_schema,
206 false,
207 )
208 }
209 Operator::And | Operator::Or => {
210 if top_level && *op == Operator::And {
213 extract_non_nullable_columns(
214 left,
215 non_nullable_cols,
216 left_schema,
217 right_schema,
218 top_level,
219 );
220 extract_non_nullable_columns(
221 right,
222 non_nullable_cols,
223 left_schema,
224 right_schema,
225 top_level,
226 );
227 return;
228 }
229
230 let mut left_non_nullable_cols: Vec<Column> = vec![];
231 let mut right_non_nullable_cols: Vec<Column> = vec![];
232
233 extract_non_nullable_columns(
234 left,
235 &mut left_non_nullable_cols,
236 left_schema,
237 right_schema,
238 top_level,
239 );
240 extract_non_nullable_columns(
241 right,
242 &mut right_non_nullable_cols,
243 left_schema,
244 right_schema,
245 top_level,
246 );
247
248 if !left_non_nullable_cols.is_empty()
255 && !right_non_nullable_cols.is_empty()
256 {
257 for left_col in &left_non_nullable_cols {
258 for right_col in &right_non_nullable_cols {
259 if (left_schema.has_column(left_col)
260 && left_schema.has_column(right_col))
261 || (right_schema.has_column(left_col)
262 && right_schema.has_column(right_col))
263 {
264 non_nullable_cols.push(left_col.clone());
265 break;
266 }
267 }
268 }
269 }
270 }
271 _ => {}
272 },
273 Expr::Not(arg) => extract_non_nullable_columns(
274 arg,
275 non_nullable_cols,
276 left_schema,
277 right_schema,
278 false,
279 ),
280 Expr::IsNotNull(arg) => {
281 if !top_level {
282 return;
283 }
284 extract_non_nullable_columns(
285 arg,
286 non_nullable_cols,
287 left_schema,
288 right_schema,
289 false,
290 )
291 }
292 Expr::Cast(Cast { expr, data_type: _ })
293 | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns(
294 expr,
295 non_nullable_cols,
296 left_schema,
297 right_schema,
298 false,
299 ),
300 _ => {}
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use crate::OptimizerContext;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        binary_expr, cast, col, lit,
        logical_plan::builder::LogicalPlanBuilder,
        try_cast,
        Operator::{And, Or},
    };

    /// Runs only `EliminateOuterJoin` (single pass) over `$plan` and compares
    /// the result with the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateOuterJoin::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Builds `Filter(predicate)` over `t1 <join_type> JOIN t2 ON t1.a = t2.a`.
    fn filtered_join(join_type: JoinType, predicate: Expr) -> Result<LogicalPlan> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        LogicalPlanBuilder::from(t1)
            .join(
                t2,
                join_type,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(predicate)?
            .build()
    }

    #[test]
    fn eliminate_left_with_null() -> Result<()> {
        // `IS NULL` on the right side yields no non-nullable column, so the
        // left join must stay.
        let plan = filtered_join(JoinType::Left, col("t2.b").is_null())?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NULL
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_left_with_not_null() -> Result<()> {
        // `IS NOT NULL` on the right side turns the left join into an inner join.
        let plan = filtered_join(JoinType::Left, col("t2.b").is_not_null())?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NOT NULL
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_right_with_or() -> Result<()> {
        // Both OR branches reference the left input, so the right join
        // becomes an inner join.
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            Or,
            col("t1.c").lt(lit(20u32)),
        );
        let plan = filtered_join(JoinType::Right, predicate)?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_and() -> Result<()> {
        // The AND references one column from each input, so the full join
        // becomes an inner join.
        let predicate = binary_expr(
            col("t1.b").gt(lit(10u32)),
            And,
            col("t2.c").lt(lit(20u32)),
        );
        let plan = filtered_join(JoinType::Full, predicate)?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_type_cast() -> Result<()> {
        // Same as above, but the columns are wrapped in CAST / TRY_CAST.
        let predicate = binary_expr(
            cast(col("t1.b"), DataType::Int64).gt(lit(10u32)),
            And,
            try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
        );
        let plan = filtered_join(JoinType::Full, predicate)?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }
}