Skip to main content

datafusion_optimizer/
eliminate_cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27    Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema};
31
/// Optimizer rule that rewrites `CROSS JOIN`s into `INNER JOIN`s when
/// equijoin predicates are available from a parent filter or from nested
/// inner joins.
///
/// See the [`OptimizerRule`] implementation for details and examples.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new [`EliminateCrossJoin`] rule.
    pub fn new() -> Self {
        Self
    }
}
41
42/// Eliminate cross joins by rewriting them to inner joins when possible.
43///
44/// # Example
45/// The initial plan for this query:
46/// ```sql
47/// select ... from a, b where a.x = b.y and b.xx = 100;
48/// ```
49///
50/// Looks like this:
51/// ```text
52/// Filter(a.x = b.y AND b.xx = 100)
53///  Cross Join
54///   TableScan a
55///   TableScan b
56/// ```
57///
58/// After the rule is applied, the plan will look like this:
59/// ```text
60/// Filter(b.xx = 100)
61///   InnerJoin(a.x = b.y)
62///     TableScan a
63///     TableScan b
64/// ```
65///
66/// # Other Examples
67/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
68/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
69/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
70/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);'
71/// * 'select ... from a, b where a.x > b.y'
72///
73/// For above queries, the join predicate is available in filters and they are moved to
74/// join nodes appropriately
75///
76/// This fix helps to improve the performance of TPCH Q19. issue#78
77impl OptimizerRule for EliminateCrossJoin {
78    fn supports_rewrite(&self) -> bool {
79        true
80    }
81
82    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        let plan_schema = Arc::clone(plan.schema());
89        let mut possible_join_keys = JoinKeySet::new();
90        let mut all_inputs: Vec<LogicalPlan> = vec![];
91        let mut all_filters: Vec<Expr> = vec![];
92        let mut null_equality = NullEquality::NullEqualsNothing;
93
94        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
95            // if input isn't a join that can potentially be rewritten
96            // avoid unwrapping the input
97            let rewritable = matches!(
98                filter.input.as_ref(),
99                LogicalPlan::Join(Join {
100                    join_type: JoinType::Inner,
101                    ..
102                })
103            );
104
105            if !rewritable {
106                // recursively try to rewrite children
107                return rewrite_children(self, LogicalPlan::Filter(filter), config);
108            }
109
110            if !can_flatten_join_inputs(&filter.input) {
111                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
112            }
113
114            let Filter {
115                input, predicate, ..
116            } = filter;
117
118            // Extract null_equality setting from the input join
119            if let LogicalPlan::Join(join) = input.as_ref() {
120                null_equality = join.null_equality;
121            }
122
123            flatten_join_inputs(
124                Arc::unwrap_or_clone(input),
125                &mut possible_join_keys,
126                &mut all_inputs,
127                &mut all_filters,
128            )?;
129
130            extract_possible_join_keys(&predicate, &mut possible_join_keys);
131            Some(predicate)
132        } else {
133            match plan {
134                LogicalPlan::Join(Join {
135                    join_type: JoinType::Inner,
136                    null_equality: original_null_equality,
137                    ..
138                }) => {
139                    if !can_flatten_join_inputs(&plan) {
140                        return Ok(Transformed::no(plan));
141                    }
142                    flatten_join_inputs(
143                        plan,
144                        &mut possible_join_keys,
145                        &mut all_inputs,
146                        &mut all_filters,
147                    )?;
148                    null_equality = original_null_equality;
149                    None
150                }
151                _ => {
152                    // recursively try to rewrite children
153                    return rewrite_children(self, plan, config);
154                }
155            }
156        };
157
158        // Join keys are handled locally:
159        let mut all_join_keys = JoinKeySet::new();
160        let mut left = all_inputs.remove(0);
161        while !all_inputs.is_empty() {
162            left = find_inner_join(
163                left,
164                &mut all_inputs,
165                &possible_join_keys,
166                &mut all_join_keys,
167                null_equality,
168            )?;
169        }
170
171        left = rewrite_children(self, left, config)?.data;
172
173        if &plan_schema != left.schema() {
174            left = LogicalPlan::Projection(Projection::new_from_schema(
175                Arc::new(left),
176                Arc::clone(&plan_schema),
177            ));
178        }
179
180        if !all_filters.is_empty() {
181            // Add any filters on top - PushDownFilter can push filters down to applicable join
182            let first = all_filters.swap_remove(0);
183            let predicate = all_filters.into_iter().fold(first, and);
184            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
185        }
186
187        let Some(predicate) = parent_predicate else {
188            return Ok(Transformed::yes(left));
189        };
190
191        // If there are no join keys then do nothing:
192        if all_join_keys.is_empty() {
193            Filter::try_new(predicate, Arc::new(left))
194                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
195        } else {
196            // Remove join expressions from filter:
197            match remove_join_expressions(predicate, &all_join_keys) {
198                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
199                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
200                _ => Ok(Transformed::yes(left)),
201            }
202        }
203    }
204
205    fn name(&self) -> &str {
206        "eliminate_cross_join"
207    }
208}
209
210fn rewrite_children(
211    optimizer: &impl OptimizerRule,
212    plan: LogicalPlan,
213    config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215    let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
216
217    // recompute schema if the plan was transformed
218    if transformed_plan.transformed {
219        transformed_plan.map_data(|plan| plan.recompute_schema())
220    } else {
221        Ok(transformed_plan)
222    }
223}
224
225/// Recursively accumulate possible_join_keys and inputs from inner joins
226/// (including cross joins).
227///
228/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
229/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
230/// possible_join_keys
231fn flatten_join_inputs(
232    plan: LogicalPlan,
233    possible_join_keys: &mut JoinKeySet,
234    all_inputs: &mut Vec<LogicalPlan>,
235    all_filters: &mut Vec<Expr>,
236) -> Result<()> {
237    match plan {
238        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
239            if let Some(filter) = join.filter {
240                all_filters.push(filter);
241            }
242            possible_join_keys.insert_all_owned(join.on);
243            flatten_join_inputs(
244                Arc::unwrap_or_clone(join.left),
245                possible_join_keys,
246                all_inputs,
247                all_filters,
248            )?;
249            flatten_join_inputs(
250                Arc::unwrap_or_clone(join.right),
251                possible_join_keys,
252                all_inputs,
253                all_filters,
254            )?;
255        }
256        _ => {
257            all_inputs.push(plan);
258        }
259    };
260    Ok(())
261}
262
263/// Returns true if the plan is a Join or Cross join could be flattened with
264/// `flatten_join_inputs`
265///
266/// Must stay in sync with `flatten_join_inputs`
267fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
268    // can only flatten inner / cross joins
269    match plan {
270        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
271        _ => return false,
272    };
273
274    for child in plan.inputs() {
275        if let LogicalPlan::Join(Join {
276            join_type: JoinType::Inner,
277            ..
278        }) = child
279            && !can_flatten_join_inputs(child)
280        {
281            return false;
282        }
283    }
284    true
285}
286
287/// Finds the next to join with the left input plan,
288///
289/// Finds the next `right` from `rights` that can be joined with `left_input`
290/// plan based on the join keys in `possible_join_keys`.
291///
292/// If such a matching `right` is found:
293/// 1. Adds the matching join keys to `all_join_keys`.
294/// 2. Returns `left_input JOIN right ON (all join keys)`.
295///
296/// If no matching `right` is found:
297/// 1. Removes the first plan from `rights`
298/// 2. Returns `left_input CROSS JOIN right`.
299fn find_inner_join(
300    left_input: LogicalPlan,
301    rights: &mut Vec<LogicalPlan>,
302    possible_join_keys: &JoinKeySet,
303    all_join_keys: &mut JoinKeySet,
304    null_equality: NullEquality,
305) -> Result<LogicalPlan> {
306    for (i, right_input) in rights.iter().enumerate() {
307        let mut join_keys = vec![];
308
309        for (l, r) in possible_join_keys.iter() {
310            let key_pair = find_valid_equijoin_key_pair(
311                l,
312                r,
313                left_input.schema(),
314                right_input.schema(),
315            )?;
316
317            // Save join keys
318            if let Some((valid_l, valid_r)) = key_pair
319                && can_hash(&valid_l.get_type(left_input.schema())?)
320            {
321                join_keys.push((valid_l, valid_r));
322            }
323        }
324
325        // Found one or more matching join keys
326        if !join_keys.is_empty() {
327            all_join_keys.insert_all(join_keys.iter());
328            let right_input = rights.remove(i);
329            let join_schema = Arc::new(build_join_schema(
330                left_input.schema(),
331                right_input.schema(),
332                &JoinType::Inner,
333            )?);
334
335            return Ok(LogicalPlan::Join(Join {
336                left: Arc::new(left_input),
337                right: Arc::new(right_input),
338                join_type: JoinType::Inner,
339                join_constraint: JoinConstraint::On,
340                on: join_keys,
341                filter: None,
342                schema: join_schema,
343                null_equality,
344                null_aware: false,
345            }));
346        }
347    }
348
349    // no matching right plan had any join keys, cross join with the first right
350    // plan
351    let right = rights.remove(0);
352    let join_schema = Arc::new(build_join_schema(
353        left_input.schema(),
354        right.schema(),
355        &JoinType::Inner,
356    )?);
357
358    Ok(LogicalPlan::Join(Join {
359        left: Arc::new(left_input),
360        right: Arc::new(right),
361        schema: join_schema,
362        on: vec![],
363        filter: None,
364        join_type: JoinType::Inner,
365        join_constraint: JoinConstraint::On,
366        null_equality,
367        null_aware: false,
368    }))
369}
370
371/// Extract join keys from a WHERE clause
372fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
373    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
374        match op {
375            Operator::Eq => {
376                // insert handles ensuring  we don't add the same Join keys multiple times
377                join_keys.insert(left, right);
378            }
379            Operator::And => {
380                extract_possible_join_keys(left, join_keys);
381                extract_possible_join_keys(right, join_keys)
382            }
383            // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
384            Operator::Or => {
385                let mut left_join_keys = JoinKeySet::new();
386                let mut right_join_keys = JoinKeySet::new();
387
388                extract_possible_join_keys(left, &mut left_join_keys);
389                extract_possible_join_keys(right, &mut right_join_keys);
390
391                join_keys.insert_intersection(&left_join_keys, &right_join_keys)
392            }
393            _ => (),
394        };
395    }
396}
397
398/// Remove join expressions from a filter expression
399///
400/// # Returns
401/// * `Some()` when there are few remaining predicates in filter_expr
402/// * `None` otherwise
403fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
404    match expr {
405        Expr::BinaryExpr(BinaryExpr {
406            left,
407            op: Operator::Eq,
408            right,
409        }) if join_keys.contains(&left, &right) => {
410            // was a join key, so remove it
411            None
412        }
413        // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
414        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
415            let l = remove_join_expressions(*left, join_keys);
416            let r = remove_join_expressions(*right, join_keys);
417            match (l, r) {
418                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
419                    Box::new(ll),
420                    op,
421                    Box::new(rr),
422                ))),
423                (Some(ll), _) => Some(ll),
424                (_, Some(rr)) => Some(rr),
425                _ => None,
426            }
427        }
428        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
429            let l = remove_join_expressions(*left, join_keys);
430            let r = remove_join_expressions(*right, join_keys);
431            match (l, r) {
432                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
433                    Box::new(ll),
434                    op,
435                    Box::new(rr),
436                ))),
437                // When either `left` or `right` is empty, it means they are `true`
438                // so OR'ing anything with them will also be true
439                _ => None,
440            }
441        }
442        _ => Some(expr),
443    }
444}
445
446#[cfg(test)]
447mod tests {
448    use super::*;
449    use crate::optimizer::OptimizerContext;
450    use crate::test::*;
451
452    use datafusion_expr::{
453        Operator::{And, Or},
454        binary_expr, col, lit,
455        logical_plan::builder::LogicalPlanBuilder,
456    };
457    use insta::assert_snapshot;
458
459    macro_rules! assert_optimized_plan_equal {
460        (
461            $plan:expr,
462            @ $expected:literal $(,)?
463        ) => {{
464            let starting_schema = Arc::clone($plan.schema());
465            let rule = EliminateCrossJoin::new();
466            let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
467            let formatted_plan = optimized_plan.display_indent_schema();
468            // Ensure the rule was actually applied
469            assert!(is_plan_transformed, "failed to optimize plan");
470            // Verify the schema remains unchanged
471            assert_eq!(&starting_schema, optimized_plan.schema());
472            assert_snapshot!(
473                formatted_plan,
474                @ $expected,
475            );
476
477            Ok(())
478        }};
479    }
480
481    #[test]
482    fn eliminate_cross_with_simple_and() -> Result<()> {
483        let t1 = test_table_scan_with_name("t1")?;
484        let t2 = test_table_scan_with_name("t2")?;
485
486        // could eliminate to inner join since filter has Join predicates
487        let plan = LogicalPlanBuilder::from(t1)
488            .cross_join(t2)?
489            .filter(binary_expr(
490                col("t1.a").eq(col("t2.a")),
491                And,
492                col("t2.c").lt(lit(20u32)),
493            ))?
494            .build()?;
495
496        assert_optimized_plan_equal!(
497            plan,
498            @ r"
499        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
500          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
501            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
502            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
503        "
504        )
505    }
506
507    #[test]
508    fn eliminate_cross_with_simple_or() -> Result<()> {
509        let t1 = test_table_scan_with_name("t1")?;
510        let t2 = test_table_scan_with_name("t2")?;
511
512        // could not eliminate to inner join since filter OR expression and there is no common
513        // Join predicates in left and right of OR expr.
514        let plan = LogicalPlanBuilder::from(t1)
515            .cross_join(t2)?
516            .filter(binary_expr(
517                col("t1.a").eq(col("t2.a")),
518                Or,
519                col("t2.b").eq(col("t1.a")),
520            ))?
521            .build()?;
522
523        assert_optimized_plan_equal!(
524            plan,
525            @ r"
526        Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
527          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
528            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
529            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
530        "
531        )
532    }
533
534    #[test]
535    fn eliminate_cross_with_and() -> Result<()> {
536        let t1 = test_table_scan_with_name("t1")?;
537        let t2 = test_table_scan_with_name("t2")?;
538
539        // could eliminate to inner join
540        let plan = LogicalPlanBuilder::from(t1)
541            .cross_join(t2)?
542            .filter(binary_expr(
543                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
544                And,
545                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
546            ))?
547            .build()?;
548
549        assert_optimized_plan_equal!(
550            plan,
551            @ r"
552        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
553          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
554            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
555            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
556        "
557        )
558    }
559
560    #[test]
561    fn eliminate_cross_with_or() -> Result<()> {
562        let t1 = test_table_scan_with_name("t1")?;
563        let t2 = test_table_scan_with_name("t2")?;
564
565        // could eliminate to inner join since Or predicates have common Join predicates
566        let plan = LogicalPlanBuilder::from(t1)
567            .cross_join(t2)?
568            .filter(binary_expr(
569                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
570                Or,
571                binary_expr(
572                    col("t1.a").eq(col("t2.a")),
573                    And,
574                    col("t2.c").eq(lit(688u32)),
575                ),
576            ))?
577            .build()?;
578
579        assert_optimized_plan_equal!(
580            plan,
581            @ r"
582        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
583          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
584            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
585            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
586        "
587        )
588    }
589
590    #[test]
591    fn eliminate_cross_not_possible_simple() -> Result<()> {
592        let t1 = test_table_scan_with_name("t1")?;
593        let t2 = test_table_scan_with_name("t2")?;
594
595        // could not eliminate to inner join
596        let plan = LogicalPlanBuilder::from(t1)
597            .cross_join(t2)?
598            .filter(binary_expr(
599                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
600                Or,
601                binary_expr(
602                    col("t1.b").eq(col("t2.b")),
603                    And,
604                    col("t2.c").eq(lit(688u32)),
605                ),
606            ))?
607            .build()?;
608
609        assert_optimized_plan_equal!(
610            plan,
611            @ r"
612        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
613          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
614            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
615            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
616        "
617        )
618    }
619
620    #[test]
621    fn eliminate_cross_not_possible() -> Result<()> {
622        let t1 = test_table_scan_with_name("t1")?;
623        let t2 = test_table_scan_with_name("t2")?;
624
625        // could not eliminate to inner join
626        let plan = LogicalPlanBuilder::from(t1)
627            .cross_join(t2)?
628            .filter(binary_expr(
629                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
630                Or,
631                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
632            ))?
633            .build()?;
634
635        assert_optimized_plan_equal!(
636            plan,
637            @ r"
638        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
639          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
640            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
641            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
642        "
643        )
644    }
645
646    #[test]
647    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
648        let t1 = test_table_scan_with_name("t1")?;
649        let t2 = test_table_scan_with_name("t2")?;
650        let t3 = test_table_scan_with_name("t3")?;
651
652        // could not eliminate to inner join with filter
653        let plan = LogicalPlanBuilder::from(t1)
654            .join(
655                t3,
656                JoinType::Inner,
657                (vec!["t1.a"], vec!["t3.a"]),
658                Some(col("t1.a").gt(lit(20u32))),
659            )?
660            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
661            .filter(col("t1.a").gt(lit(15u32)))?
662            .build()?;
663
664        assert_optimized_plan_equal!(
665            plan,
666            @ r"
667        Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
668          Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
669            Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
670              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
671                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
672                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
673              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
674        "
675        )
676    }
677
678    #[test]
679    /// ```txt
680    /// filter: a.id = b.id and a.id = c.id
681    ///   cross_join a (bc)
682    ///     cross_join b c
683    /// ```
684    /// Without reorder, it will be
685    /// ```txt
686    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
687    ///     cross_join b c
688    /// ```
689    /// Reorder it to be
690    /// ```txt
691    ///   inner_join (ab)c and a.id = c.id
692    ///     inner_join a b on a.id = b.id
693    /// ```
694    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
695        let t1 = test_table_scan_with_name("t1")?;
696        let t2 = test_table_scan_with_name("t2")?;
697        let t3 = test_table_scan_with_name("t3")?;
698
699        // could eliminate to inner join
700        let plan = LogicalPlanBuilder::from(t1)
701            .cross_join(t2)?
702            .cross_join(t3)?
703            .filter(binary_expr(
704                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
705                And,
706                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
707            ))?
708            .build()?;
709
710        assert_optimized_plan_equal!(
711            plan,
712            @ r"
713        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
714          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
715            Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
716              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
717                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
718                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
719              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
720        "
721        )
722    }
723
724    #[test]
725    fn eliminate_cross_join_multi_tables() -> Result<()> {
726        let t1 = test_table_scan_with_name("t1")?;
727        let t2 = test_table_scan_with_name("t2")?;
728        let t3 = test_table_scan_with_name("t3")?;
729        let t4 = test_table_scan_with_name("t4")?;
730
731        // could eliminate to inner join
732        let plan1 = LogicalPlanBuilder::from(t1)
733            .cross_join(t2)?
734            .filter(binary_expr(
735                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
736                Or,
737                binary_expr(
738                    col("t1.a").eq(col("t2.a")),
739                    And,
740                    col("t2.c").eq(lit(688u32)),
741                ),
742            ))?
743            .build()?;
744
745        let plan2 = LogicalPlanBuilder::from(t3)
746            .cross_join(t4)?
747            .filter(binary_expr(
748                binary_expr(
749                    binary_expr(
750                        col("t3.a").eq(col("t4.a")),
751                        And,
752                        col("t4.c").lt(lit(15u32)),
753                    ),
754                    Or,
755                    binary_expr(
756                        col("t3.a").eq(col("t4.a")),
757                        And,
758                        col("t3.c").eq(lit(688u32)),
759                    ),
760                ),
761                Or,
762                binary_expr(
763                    col("t3.a").eq(col("t4.a")),
764                    And,
765                    col("t3.b").eq(col("t4.b")),
766                ),
767            ))?
768            .build()?;
769
770        let plan = LogicalPlanBuilder::from(plan1)
771            .cross_join(plan2)?
772            .filter(binary_expr(
773                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
774                Or,
775                binary_expr(
776                    col("t3.a").eq(col("t1.a")),
777                    And,
778                    col("t4.c").eq(lit(688u32)),
779                ),
780            ))?
781            .build()?;
782
783        assert_optimized_plan_equal!(
784            plan,
785            @ r"
786        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
787          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
788            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
789              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
790                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
791                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
792            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
793              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
794                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
795                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
796        "
797        )
798    }
799
/// Both inner cross joins carry a join key shared by every OR branch and are
/// rewritten; the outer filter has no such shared key, so the top cross join stays.
#[test]
fn eliminate_cross_join_multi_tables_1() -> Result<()> {
    let (t1, t2, t3, t4) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
        test_table_scan_with_name("t3")?,
        test_table_scan_with_name("t4")?,
    );

    // (t1.a = t2.a AND t2.c < 15) OR (t1.a = t2.a AND t2.c = 688)
    let key_t1_t2 = col("t1.a").eq(col("t2.a"));
    let left_filter = key_t1_t2
        .clone()
        .and(col("t2.c").lt(lit(15u32)))
        .or(key_t1_t2.and(col("t2.c").eq(lit(688u32))));
    let plan1 = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(left_filter)?
        .build()?;

    // (t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688)
    //   OR (t3.a = t4.a AND t3.b = t4.b)
    let key_t3_t4 = col("t3.a").eq(col("t4.a"));
    let right_filter = key_t3_t4
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t4.clone().and(col("t3.c").eq(lit(688u32))))
        .or(key_t3_t4.and(col("t3.b").eq(col("t4.b"))));
    let plan2 = LogicalPlanBuilder::from(t3)
        .cross_join(t4)?
        .filter(right_filter)?
        .build()?;

    // (t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a OR t4.c = 688):
    // the second branch is a plain OR, so the top-level cross join is kept.
    let key_t3_t1 = col("t3.a").eq(col("t1.a"));
    let outer_filter = key_t3_t1
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t1.or(col("t4.c").eq(lit(688u32))));
    let plan = LogicalPlanBuilder::from(plan1)
        .cross_join(plan2)?
        .filter(outer_filter)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
873
/// The t1/t2 side and the outer filter each share a join key across their OR
/// branches; the t3/t4 side ends in a plain OR, so that cross join is kept.
#[test]
fn eliminate_cross_join_multi_tables_2() -> Result<()> {
    let (t1, t2, t3, t4) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
        test_table_scan_with_name("t3")?,
        test_table_scan_with_name("t4")?,
    );

    // (t1.a = t2.a AND t2.c < 15) OR (t1.a = t2.a AND t2.c = 688)
    let key_t1_t2 = col("t1.a").eq(col("t2.a"));
    let left_filter = key_t1_t2
        .clone()
        .and(col("t2.c").lt(lit(15u32)))
        .or(key_t1_t2.and(col("t2.c").eq(lit(688u32))));
    let plan1 = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(left_filter)?
        .build()?;

    // (t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688)
    //   OR (t3.a = t4.a OR t3.b = t4.b)
    let key_t3_t4 = col("t3.a").eq(col("t4.a"));
    let right_filter = key_t3_t4
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t4.clone().and(col("t3.c").eq(lit(688u32))))
        .or(key_t3_t4.or(col("t3.b").eq(col("t4.b"))));
    let plan2 = LogicalPlanBuilder::from(t3)
        .cross_join(t4)?
        .filter(right_filter)?
        .build()?;

    // (t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)
    let key_t3_t1 = col("t3.a").eq(col("t1.a"));
    let outer_filter = key_t3_t1
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t1.and(col("t4.c").eq(lit(688u32))));
    let plan = LogicalPlanBuilder::from(plan1)
        .cross_join(plan2)?
        .filter(outer_filter)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
947
/// The t1/t2 filter starts with a plain OR, so that cross join is kept; the
/// t3/t4 side and the outer filter share join keys and are rewritten.
#[test]
fn eliminate_cross_join_multi_tables_3() -> Result<()> {
    let (t1, t2, t3, t4) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
        test_table_scan_with_name("t3")?,
        test_table_scan_with_name("t4")?,
    );

    // (t1.a = t2.a OR t2.c < 15) OR (t1.a = t2.a AND t2.c = 688)
    let key_t1_t2 = col("t1.a").eq(col("t2.a"));
    let left_filter = key_t1_t2
        .clone()
        .or(col("t2.c").lt(lit(15u32)))
        .or(key_t1_t2.and(col("t2.c").eq(lit(688u32))));
    let plan1 = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(left_filter)?
        .build()?;

    // (t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688)
    //   OR (t3.a = t4.a AND t3.b = t4.b)
    let key_t3_t4 = col("t3.a").eq(col("t4.a"));
    let right_filter = key_t3_t4
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t4.clone().and(col("t3.c").eq(lit(688u32))))
        .or(key_t3_t4.and(col("t3.b").eq(col("t4.b"))));
    let plan2 = LogicalPlanBuilder::from(t3)
        .cross_join(t4)?
        .filter(right_filter)?
        .build()?;

    // (t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)
    let key_t3_t1 = col("t3.a").eq(col("t1.a"));
    let outer_filter = key_t3_t1
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t1.and(col("t4.c").eq(lit(688u32))));
    let plan = LogicalPlanBuilder::from(plan1)
        .cross_join(plan2)?
        .filter(outer_filter)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
        Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1025
/// An unfiltered t3/t4 cross join is turned into an inner join using a key
/// that only appears in the outer filter, alongside a filtered t1/t2 side.
#[test]
fn eliminate_cross_join_multi_tables_4() -> Result<()> {
    let t1 = test_table_scan_with_name("t1")?;
    let t2 = test_table_scan_with_name("t2")?;
    let t3 = test_table_scan_with_name("t3")?;
    let t4 = test_table_scan_with_name("t4")?;

    // could eliminate to inner join: the right conjunct contains `t1.a = t2.a`
    // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688)
    let plan1 = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(binary_expr(
            binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
            And,
            binary_expr(
                col("t1.a").eq(col("t2.a")),
                And,
                col("t2.c").eq(lit(688u32)),
            ),
        ))?
        .build()?;

    // No filter of its own: this cross join can only become an inner join
    // through the `t3.a = t4.a` predicate supplied by the outer filter below.
    let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

    // could eliminate to inner join
    // filter:
    //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
    //     AND
    //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
    let plan = LogicalPlanBuilder::from(plan1)
        .cross_join(plan2)?
        .filter(binary_expr(
            binary_expr(
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").lt(lit(15u32)),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").eq(lit(688u32)),
                ),
            ),
            And,
            binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t4.a")),
                    And,
                    col("t3.b").eq(col("t4.b")),
                ),
            ),
        ))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1112
/// Two unfiltered cross joins stacked under one filter: every join key comes
/// from the single top-level predicate, which is split here into three branches.
#[test]
fn eliminate_cross_join_multi_tables_5() -> Result<()> {
    let (t1, t2, t3, t4) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
        test_table_scan_with_name("t3")?,
        test_table_scan_with_name("t4")?,
    );

    // Neither side has a filter of its own.
    let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
    let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

    // (t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)
    let key_t3_t1 = col("t3.a").eq(col("t1.a"));
    let t1_t3_branch = key_t3_t1
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t1.and(col("t4.c").eq(lit(688u32))));

    // (t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688)
    //   OR (t3.a = t4.a AND t3.b = t4.b)
    let key_t3_t4 = col("t3.a").eq(col("t4.a"));
    let t3_t4_branch = key_t3_t4
        .clone()
        .and(col("t4.c").lt(lit(15u32)))
        .or(key_t3_t4.clone().and(col("t3.c").eq(lit(688u32))))
        .or(key_t3_t4.and(col("t3.b").eq(col("t4.b"))));

    // (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688)
    let key_t1_t2 = col("t1.a").eq(col("t2.a"));
    let t1_t2_branch = key_t1_t2
        .clone()
        .or(col("t2.c").lt(lit(15u32)))
        .and(key_t1_t2.and(col("t2.c").eq(lit(688u32))));

    // ((t1_t3_branch AND t3_t4_branch) AND t1_t2_branch)
    let predicate = t1_t3_branch.and(t3_t4_branch).and(t1_t2_branch);

    let plan = LogicalPlanBuilder::from(plan1)
        .cross_join(plan2)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1204
/// An expression-valued equality ANDed with a plain predicate becomes the
/// inner-join key, leaving the plain predicate in the filter.
#[test]
fn eliminate_cross_join_with_expr_and() -> Result<()> {
    let (t1, t2) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
    );

    // (t1.a + 100 = t2.a * 2) AND t2.c < 20
    let join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let predicate = join_key.and(col("t2.c").lt(lit(20u32)));

    let plan = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1230
/// Two different equalities joined by OR share no common join predicate, so
/// the cross join must survive untouched.
#[test]
fn eliminate_cross_with_expr_or() -> Result<()> {
    let (t1, t2) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
    );

    // (t1.a + 100 = t2.a * 2) OR (t2.b = t1.a): neither side of the OR
    // contains a predicate the other side also has.
    let predicate = (col("t1.a") + lit(100u32))
        .eq(col("t2.a") * lit(2u32))
        .or(col("t2.b").eq(col("t1.a")));

    let plan = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1257
/// The same expression key repeated in two AND-ed groups is used once as the
/// join key; both residual predicates stay in the filter.
#[test]
fn eliminate_cross_with_common_expr_and() -> Result<()> {
    let (t1, t2) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
    );

    // ((key AND t2.c < 20) AND (key AND t2.c = 10))
    let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let predicate = common_join_key
        .clone()
        .and(col("t2.c").lt(lit(20u32)))
        .and(common_join_key.and(col("t2.c").eq(lit(10u32))));

    let plan = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1284
/// An expression key present in every OR branch is pulled out as the join
/// key; the remaining disjunction stays in the filter.
#[test]
fn eliminate_cross_with_common_expr_or() -> Result<()> {
    let (t1, t2) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
    );

    // (key AND t2.c < 15) OR (key AND t2.c = 688)
    let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let predicate = common_join_key
        .clone()
        .and(col("t2.c").lt(lit(15u32)))
        .or(common_join_key.and(col("t2.c").eq(lit(688u32))));

    let plan = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1311
/// Three-way cross join whose expression keys link t3 to both t1 and t2; the
/// snapshot shows the joins reordered and a projection restoring column order.
#[test]
fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
    let (t1, t2, t3) = (
        test_table_scan_with_name("t1")?,
        test_table_scan_with_name("t2")?,
        test_table_scan_with_name("t3")?,
    );

    // ((t3.a + 100 = t1.a * 2) AND t3.c < 15) AND ((t3.a + 100 = t2.a * 2) AND t3.b < 15)
    let key_t3_t1 = (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32));
    let key_t3_t2 = (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let predicate = key_t3_t1
        .and(col("t3.c").lt(lit(15u32)))
        .and(key_t3_t2.and(col("t3.b").lt(lit(15u32))));

    let plan = LogicalPlanBuilder::from(t1)
        .cross_join(t2)?
        .cross_join(t3)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1350
/// The rule must carry the input join's `null_equality` setting over to the
/// joins it builds.
#[test]
fn preserve_null_equality_setting() -> Result<()> {
    let t1 = test_table_scan_with_name("t1")?;
    let t2 = test_table_scan_with_name("t2")?;

    // Create an inner join with NullEquality::NullEqualsNull and no equi-keys.
    let join_schema = Arc::new(build_join_schema(
        t1.schema(),
        t2.schema(),
        &JoinType::Inner,
    )?);

    let inner_join = LogicalPlan::Join(Join {
        left: Arc::new(t1),
        right: Arc::new(t2),
        join_type: JoinType::Inner,
        join_constraint: JoinConstraint::On,
        on: vec![],
        filter: None,
        schema: join_schema,
        null_equality: NullEquality::NullEqualsNull, // Test preservation
        null_aware: false,
    });

    // `t1.a = t2.a` gives the rule an equi-join predicate to rewrite with.
    let plan = LogicalPlanBuilder::from(inner_join)
        .filter(binary_expr(
            col("t1.a").eq(col("t2.a")),
            And,
            col("t2.c").lt(lit(20u32)),
        ))?
        .build()?;

    let rule = EliminateCrossJoin::new();
    let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;

    /// Appends every `Join` node reachable from `plan` (pre-order) to `joins`.
    fn collect_joins<'a>(plan: &'a LogicalPlan, joins: &mut Vec<&'a Join>) {
        if let LogicalPlan::Join(join) = plan {
            joins.push(join);
        }
        for input in plan.inputs() {
            collect_joins(input, joins);
        }
    }

    let mut joins = Vec::new();
    collect_joins(&optimized_plan, &mut joins);

    // Without this, a plan with no joins at all would satisfy the check below
    // vacuously and hide a regression.
    assert!(
        !joins.is_empty(),
        "optimized plan should still contain at least one join"
    );
    assert!(
        joins
            .iter()
            .all(|join| join.null_equality == NullEquality::NullEqualsNull),
        "null_equality setting should be preserved after optimization"
    );

    Ok(())
}
1416}