1use crate::{OptimizerConfig, OptimizerRule};
20use datafusion_common::{Column, DFSchema, Result};
21use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan};
22use datafusion_expr::{Expr, Filter, Operator};
23
24use crate::optimizer::ApplyOrder;
25use datafusion_common::tree_node::Transformed;
26use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
27use std::sync::Arc;
28
/// Optimizer rule that weakens outer joins when a `Filter` directly above the
/// join rejects the NULL-padded rows the outer join would otherwise emit.
///
/// For example, a `LEFT JOIN` followed by a predicate that can never be true
/// when the right side's columns are NULL is rewritten into an `INNER JOIN`:
///
/// ```text
/// SELECT ... FROM t1 LEFT JOIN t2 ON t1.a = t2.a WHERE t2.b IS NOT NULL
/// -- becomes
/// SELECT ... FROM t1 INNER JOIN t2 ON t1.a = t2.a WHERE t2.b IS NOT NULL
/// ```
#[derive(Default, Debug)]
pub struct EliminateOuterJoin;

impl EliminateOuterJoin {
    /// Creates a new instance of the [`EliminateOuterJoin`] rule.
    pub fn new() -> Self {
        Self
    }
}
60
61impl OptimizerRule for EliminateOuterJoin {
63 fn name(&self) -> &str {
64 "eliminate_outer_join"
65 }
66
67 fn apply_order(&self) -> Option<ApplyOrder> {
68 Some(ApplyOrder::TopDown)
69 }
70
71 fn supports_rewrite(&self) -> bool {
72 true
73 }
74
75 fn rewrite(
76 &self,
77 plan: LogicalPlan,
78 _config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 match plan {
81 LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) {
82 LogicalPlan::Join(join) => {
83 let mut non_nullable_cols: Vec<Column> = vec![];
84
85 extract_non_nullable_columns(
86 &filter.predicate,
87 &mut non_nullable_cols,
88 join.left.schema(),
89 join.right.schema(),
90 true,
91 );
92
93 let new_join_type = if join.join_type.is_outer() {
94 let mut left_non_nullable = false;
95 let mut right_non_nullable = false;
96 for col in non_nullable_cols.iter() {
97 if join.left.schema().has_column(col) {
98 left_non_nullable = true;
99 }
100 if join.right.schema().has_column(col) {
101 right_non_nullable = true;
102 }
103 }
104 eliminate_outer(
105 join.join_type,
106 left_non_nullable,
107 right_non_nullable,
108 )
109 } else {
110 join.join_type
111 };
112
113 let new_join = Arc::new(LogicalPlan::Join(Join {
114 left: join.left,
115 right: join.right,
116 join_type: new_join_type,
117 join_constraint: join.join_constraint,
118 on: join.on.clone(),
119 filter: join.filter.clone(),
120 schema: Arc::clone(&join.schema),
121 null_equality: join.null_equality,
122 null_aware: join.null_aware,
123 }));
124 Filter::try_new(filter.predicate, new_join)
125 .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
126 }
127 filter_input => {
128 filter.input = Arc::new(filter_input);
129 Ok(Transformed::no(LogicalPlan::Filter(filter)))
130 }
131 },
132 _ => Ok(Transformed::no(plan)),
133 }
134 }
135}
136
137pub fn eliminate_outer(
138 join_type: JoinType,
139 left_non_nullable: bool,
140 right_non_nullable: bool,
141) -> JoinType {
142 let mut new_join_type = join_type;
143 match join_type {
144 JoinType::Left => {
145 if right_non_nullable {
146 new_join_type = JoinType::Inner;
147 }
148 }
149 JoinType::Right => {
150 if left_non_nullable {
151 new_join_type = JoinType::Inner;
152 }
153 }
154 JoinType::Full => {
155 if left_non_nullable && right_non_nullable {
156 new_join_type = JoinType::Inner;
157 } else if left_non_nullable {
158 new_join_type = JoinType::Left;
159 } else if right_non_nullable {
160 new_join_type = JoinType::Right;
161 }
162 }
163 _ => {}
164 }
165 new_join_type
166}
167
168fn extract_non_nullable_columns(
177 expr: &Expr,
178 non_nullable_cols: &mut Vec<Column>,
179 left_schema: &Arc<DFSchema>,
180 right_schema: &Arc<DFSchema>,
181 top_level: bool,
182) {
183 match expr {
184 Expr::Column(col) => {
185 non_nullable_cols.push(col.clone());
186 }
187 Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
188 Operator::Eq
190 | Operator::NotEq
191 | Operator::Lt
192 | Operator::LtEq
193 | Operator::Gt
194 | Operator::GtEq => {
195 extract_non_nullable_columns(
196 left,
197 non_nullable_cols,
198 left_schema,
199 right_schema,
200 false,
201 );
202 extract_non_nullable_columns(
203 right,
204 non_nullable_cols,
205 left_schema,
206 right_schema,
207 false,
208 )
209 }
210 Operator::And | Operator::Or => {
211 if top_level && *op == Operator::And {
214 extract_non_nullable_columns(
215 left,
216 non_nullable_cols,
217 left_schema,
218 right_schema,
219 top_level,
220 );
221 extract_non_nullable_columns(
222 right,
223 non_nullable_cols,
224 left_schema,
225 right_schema,
226 top_level,
227 );
228 return;
229 }
230
231 let mut left_non_nullable_cols: Vec<Column> = vec![];
232 let mut right_non_nullable_cols: Vec<Column> = vec![];
233
234 extract_non_nullable_columns(
235 left,
236 &mut left_non_nullable_cols,
237 left_schema,
238 right_schema,
239 top_level,
240 );
241 extract_non_nullable_columns(
242 right,
243 &mut right_non_nullable_cols,
244 left_schema,
245 right_schema,
246 top_level,
247 );
248
249 if !left_non_nullable_cols.is_empty()
256 && !right_non_nullable_cols.is_empty()
257 {
258 for left_col in &left_non_nullable_cols {
259 for right_col in &right_non_nullable_cols {
260 if (left_schema.has_column(left_col)
261 && left_schema.has_column(right_col))
262 || (right_schema.has_column(left_col)
263 && right_schema.has_column(right_col))
264 {
265 non_nullable_cols.push(left_col.clone());
266 break;
267 }
268 }
269 }
270 }
271 }
272 _ => {}
273 },
274 Expr::Not(arg) => extract_non_nullable_columns(
275 arg,
276 non_nullable_cols,
277 left_schema,
278 right_schema,
279 false,
280 ),
281 Expr::IsNotNull(arg) => {
282 if !top_level {
283 return;
284 }
285 extract_non_nullable_columns(
286 arg,
287 non_nullable_cols,
288 left_schema,
289 right_schema,
290 false,
291 )
292 }
293 Expr::Cast(Cast { expr, data_type: _ })
294 | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns(
295 expr,
296 non_nullable_cols,
297 left_schema,
298 right_schema,
299 false,
300 ),
301 _ => {}
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OptimizerContext;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::test::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        Operator::{And, Or},
        binary_expr, cast, col, lit,
        logical_plan::builder::LogicalPlanBuilder,
        try_cast,
    };

    /// Runs only the `EliminateOuterJoin` rule for a single pass and compares
    /// the result against an inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> =
                vec![Arc::new(EliminateOuterJoin::new())];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn eliminate_left_with_null() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // `IS NULL` accepts the NULL-padded rows, so the LEFT join must stay.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t2.b").is_null())?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NULL
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_left_with_not_null() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // `IS NOT NULL` on the nullable side rejects the padded rows.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t2.b").is_not_null())?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.b IS NOT NULL
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn keep_left_with_preserved_side_predicate() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // Filtering only the preserved (left) side says nothing about the
        // padded right side, so the LEFT join must stay.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t1.b").gt(lit(10u32)))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10)
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn keep_left_with_or_across_relations() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // The OR branches reference different relations, so neither side is
        // guaranteed non-NULL and the LEFT join must stay.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                col("t1.b").gt(lit(10u32)),
                Or,
                col("t2.c").lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) OR t2.c < UInt32(20)
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_right_with_or() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // Both OR branches reference t1 (the nullable side of a RIGHT join).
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Right,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                col("t1.b").gt(lit(10u32)),
                Or,
                col("t1.c").lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_and() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // Each conjunct rejects NULLs on a different side of the FULL join.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                col("t1.b").gt(lit(10u32)),
                And,
                col("t2.c").lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_to_left() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // Only the left side is forced non-NULL: FULL degrades to LEFT.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t1.b").gt(lit(10u32)))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t1.b > UInt32(10)
          Left Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_to_right() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // Only the right side is forced non-NULL: FULL degrades to RIGHT.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(col("t2.c").lt(lit(20u32)))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: t2.c < UInt32(20)
          Right Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn eliminate_full_with_type_cast() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // CAST / TRY_CAST propagate NULL, so the casted columns still count.
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                None,
            )?
            .filter(binary_expr(
                cast(col("t1.b"), DataType::Int64).gt(lit(10u32)),
                And,
                try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
            ))?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)
          Inner Join: t1.a = t2.a
            TableScan: t1
            TableScan: t2
        ")
    }
}