Skip to main content

datafusion_optimizer/
eliminate_cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27    Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema};
31
/// Optimizer rule that rewrites `CROSS JOIN`s into `INNER JOIN`s when
/// equality join predicates are available.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self
    }
}
41
42/// Eliminate cross joins by rewriting them to inner joins when possible.
43///
44/// # Example
45/// The initial plan for this query:
46/// ```sql
47/// select ... from a, b where a.x = b.y and b.xx = 100;
48/// ```
49///
50/// Looks like this:
51/// ```text
52/// Filter(a.x = b.y AND b.xx = 100)
53///  Cross Join
54///   TableScan a
55///   TableScan b
56/// ```
57///
58/// After the rule is applied, the plan will look like this:
59/// ```text
60/// Filter(b.xx = 100)
61///   InnerJoin(a.x = b.y)
62///     TableScan a
63///     TableScan b
64/// ```
65///
66/// # Other Examples
67/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
68/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
69/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
70/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);'
71/// * 'select ... from a, b where a.x > b.y'
72///
73/// For above queries, the join predicate is available in filters and they are moved to
74/// join nodes appropriately
75///
76/// This fix helps to improve the performance of TPCH Q19. issue#78
77impl OptimizerRule for EliminateCrossJoin {
78    fn supports_rewrite(&self) -> bool {
79        true
80    }
81
82    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        let plan_schema = Arc::clone(plan.schema());
89        let mut possible_join_keys = JoinKeySet::new();
90        let mut all_inputs: Vec<LogicalPlan> = vec![];
91        let mut all_filters: Vec<Expr> = vec![];
92        let mut null_equality = NullEquality::NullEqualsNothing;
93
94        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
95            // if input isn't a join that can potentially be rewritten
96            // avoid unwrapping the input
97            let rewritable = matches!(
98                filter.input.as_ref(),
99                LogicalPlan::Join(Join {
100                    join_type: JoinType::Inner,
101                    ..
102                })
103            );
104
105            if !rewritable {
106                // recursively try to rewrite children
107                return rewrite_children(self, LogicalPlan::Filter(filter), config);
108            }
109
110            if !can_flatten_join_inputs(&filter.input) {
111                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
112            }
113
114            let Filter {
115                input, predicate, ..
116            } = filter;
117
118            // Extract null_equality setting from the input join
119            if let LogicalPlan::Join(join) = input.as_ref() {
120                null_equality = join.null_equality;
121            }
122
123            flatten_join_inputs(
124                Arc::unwrap_or_clone(input),
125                &mut possible_join_keys,
126                &mut all_inputs,
127                &mut all_filters,
128            )?;
129
130            extract_possible_join_keys(&predicate, &mut possible_join_keys);
131            Some(predicate)
132        } else {
133            match plan {
134                LogicalPlan::Join(Join {
135                    join_type: JoinType::Inner,
136                    null_equality: original_null_equality,
137                    ..
138                }) => {
139                    if !can_flatten_join_inputs(&plan) {
140                        return Ok(Transformed::no(plan));
141                    }
142                    flatten_join_inputs(
143                        plan,
144                        &mut possible_join_keys,
145                        &mut all_inputs,
146                        &mut all_filters,
147                    )?;
148                    null_equality = original_null_equality;
149                    None
150                }
151                _ => {
152                    // recursively try to rewrite children
153                    return rewrite_children(self, plan, config);
154                }
155            }
156        };
157
158        // Join keys are handled locally:
159        let mut all_join_keys = JoinKeySet::new();
160        let mut left = all_inputs.remove(0);
161        while !all_inputs.is_empty() {
162            left = find_inner_join(
163                left,
164                &mut all_inputs,
165                &possible_join_keys,
166                &mut all_join_keys,
167                null_equality,
168            )?;
169        }
170
171        left = rewrite_children(self, left, config)?.data;
172
173        if &plan_schema != left.schema() {
174            left = LogicalPlan::Projection(Projection::new_from_schema(
175                Arc::new(left),
176                Arc::clone(&plan_schema),
177            ));
178        }
179
180        if !all_filters.is_empty() {
181            // Add any filters on top - PushDownFilter can push filters down to applicable join
182            let first = all_filters.swap_remove(0);
183            let predicate = all_filters.into_iter().fold(first, and);
184            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
185        }
186
187        let Some(predicate) = parent_predicate else {
188            return Ok(Transformed::yes(left));
189        };
190
191        // If there are no join keys then do nothing:
192        if all_join_keys.is_empty() {
193            Filter::try_new(predicate, Arc::new(left))
194                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
195        } else {
196            // Remove join expressions from filter:
197            match remove_join_expressions(predicate, &all_join_keys) {
198                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
199                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
200                _ => Ok(Transformed::yes(left)),
201            }
202        }
203    }
204
205    fn name(&self) -> &str {
206        "eliminate_cross_join"
207    }
208}
209
210fn rewrite_children(
211    optimizer: &impl OptimizerRule,
212    plan: LogicalPlan,
213    config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215    // Process uncorrelated subqueries in expressions, then direct children.
216    let transformed_plan = plan
217        .map_uncorrelated_subqueries(|input| optimizer.rewrite(input, config))?
218        .transform_sibling(|plan| {
219            plan.map_children(|input| optimizer.rewrite(input, config))
220        })?;
221
222    // recompute schema if the plan was transformed
223    if transformed_plan.transformed {
224        transformed_plan.map_data(|plan| plan.recompute_schema())
225    } else {
226        Ok(transformed_plan)
227    }
228}
229
230/// Recursively accumulate possible_join_keys and inputs from inner joins
231/// (including cross joins).
232///
233/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
234/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
235/// possible_join_keys
236fn flatten_join_inputs(
237    plan: LogicalPlan,
238    possible_join_keys: &mut JoinKeySet,
239    all_inputs: &mut Vec<LogicalPlan>,
240    all_filters: &mut Vec<Expr>,
241) -> Result<()> {
242    match plan {
243        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
244            if let Some(filter) = join.filter {
245                all_filters.push(filter);
246            }
247            possible_join_keys.insert_all_owned(join.on);
248            flatten_join_inputs(
249                Arc::unwrap_or_clone(join.left),
250                possible_join_keys,
251                all_inputs,
252                all_filters,
253            )?;
254            flatten_join_inputs(
255                Arc::unwrap_or_clone(join.right),
256                possible_join_keys,
257                all_inputs,
258                all_filters,
259            )?;
260        }
261        _ => {
262            all_inputs.push(plan);
263        }
264    };
265    Ok(())
266}
267
268/// Returns true if the plan is a Join or Cross join could be flattened with
269/// `flatten_join_inputs`
270///
271/// Must stay in sync with `flatten_join_inputs`
272fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
273    // can only flatten inner / cross joins
274    match plan {
275        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
276        _ => return false,
277    };
278
279    for child in plan.inputs() {
280        if let LogicalPlan::Join(Join {
281            join_type: JoinType::Inner,
282            ..
283        }) = child
284            && !can_flatten_join_inputs(child)
285        {
286            return false;
287        }
288    }
289    true
290}
291
292/// Finds the next to join with the left input plan,
293///
294/// Finds the next `right` from `rights` that can be joined with `left_input`
295/// plan based on the join keys in `possible_join_keys`.
296///
297/// If such a matching `right` is found:
298/// 1. Adds the matching join keys to `all_join_keys`.
299/// 2. Returns `left_input JOIN right ON (all join keys)`.
300///
301/// If no matching `right` is found:
302/// 1. Removes the first plan from `rights`
303/// 2. Returns `left_input CROSS JOIN right`.
304fn find_inner_join(
305    left_input: LogicalPlan,
306    rights: &mut Vec<LogicalPlan>,
307    possible_join_keys: &JoinKeySet,
308    all_join_keys: &mut JoinKeySet,
309    null_equality: NullEquality,
310) -> Result<LogicalPlan> {
311    for (i, right_input) in rights.iter().enumerate() {
312        let mut join_keys = vec![];
313
314        for (l, r) in possible_join_keys.iter() {
315            let key_pair = find_valid_equijoin_key_pair(
316                l,
317                r,
318                left_input.schema(),
319                right_input.schema(),
320            )?;
321
322            // Save join keys
323            if let Some((valid_l, valid_r)) = key_pair
324                && can_hash(&valid_l.get_type(left_input.schema())?)
325            {
326                join_keys.push((valid_l, valid_r));
327            }
328        }
329
330        // Found one or more matching join keys
331        if !join_keys.is_empty() {
332            all_join_keys.insert_all(join_keys.iter());
333            let right_input = rights.remove(i);
334            let join_schema = Arc::new(build_join_schema(
335                left_input.schema(),
336                right_input.schema(),
337                &JoinType::Inner,
338            )?);
339
340            return Ok(LogicalPlan::Join(Join {
341                left: Arc::new(left_input),
342                right: Arc::new(right_input),
343                join_type: JoinType::Inner,
344                join_constraint: JoinConstraint::On,
345                on: join_keys,
346                filter: None,
347                schema: join_schema,
348                null_equality,
349                null_aware: false,
350            }));
351        }
352    }
353
354    // no matching right plan had any join keys, cross join with the first right
355    // plan
356    let right = rights.remove(0);
357    let join_schema = Arc::new(build_join_schema(
358        left_input.schema(),
359        right.schema(),
360        &JoinType::Inner,
361    )?);
362
363    Ok(LogicalPlan::Join(Join {
364        left: Arc::new(left_input),
365        right: Arc::new(right),
366        schema: join_schema,
367        on: vec![],
368        filter: None,
369        join_type: JoinType::Inner,
370        join_constraint: JoinConstraint::On,
371        null_equality,
372        null_aware: false,
373    }))
374}
375
376/// Extract join keys from a WHERE clause
377fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
378    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
379        match op {
380            Operator::Eq => {
381                // insert handles ensuring  we don't add the same Join keys multiple times
382                join_keys.insert(left, right);
383            }
384            Operator::And => {
385                extract_possible_join_keys(left, join_keys);
386                extract_possible_join_keys(right, join_keys)
387            }
388            // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
389            Operator::Or => {
390                let mut left_join_keys = JoinKeySet::new();
391                let mut right_join_keys = JoinKeySet::new();
392
393                extract_possible_join_keys(left, &mut left_join_keys);
394                extract_possible_join_keys(right, &mut right_join_keys);
395
396                join_keys.insert_intersection(&left_join_keys, &right_join_keys)
397            }
398            _ => (),
399        };
400    }
401}
402
403/// Remove join expressions from a filter expression
404///
405/// # Returns
406/// * `Some()` when there are few remaining predicates in filter_expr
407/// * `None` otherwise
408fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
409    match expr {
410        Expr::BinaryExpr(BinaryExpr {
411            left,
412            op: Operator::Eq,
413            right,
414        }) if join_keys.contains(&left, &right) => {
415            // was a join key, so remove it
416            None
417        }
418        // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
419        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
420            let l = remove_join_expressions(*left, join_keys);
421            let r = remove_join_expressions(*right, join_keys);
422            match (l, r) {
423                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
424                    Box::new(ll),
425                    op,
426                    Box::new(rr),
427                ))),
428                (Some(ll), _) => Some(ll),
429                (_, Some(rr)) => Some(rr),
430                _ => None,
431            }
432        }
433        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
434            let l = remove_join_expressions(*left, join_keys);
435            let r = remove_join_expressions(*right, join_keys);
436            match (l, r) {
437                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
438                    Box::new(ll),
439                    op,
440                    Box::new(rr),
441                ))),
442                // When either `left` or `right` is empty, it means they are `true`
443                // so OR'ing anything with them will also be true
444                _ => None,
445            }
446        }
447        _ => Some(expr),
448    }
449}
450
451#[cfg(test)]
452mod tests {
453    use super::*;
454    use crate::optimizer::OptimizerContext;
455    use crate::test::*;
456
457    use datafusion_expr::{
458        Operator::{And, Or},
459        binary_expr, col, lit,
460        logical_plan::builder::LogicalPlanBuilder,
461    };
462    use insta::assert_snapshot;
463
464    macro_rules! assert_optimized_plan_equal {
465        (
466            $plan:expr,
467            @ $expected:literal $(,)?
468        ) => {{
469            let starting_schema = Arc::clone($plan.schema());
470            let rule = EliminateCrossJoin::new();
471            let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
472            let formatted_plan = optimized_plan.display_indent_schema();
473            // Ensure the rule was actually applied
474            assert!(is_plan_transformed, "failed to optimize plan");
475            // Verify the schema remains unchanged
476            assert_eq!(&starting_schema, optimized_plan.schema());
477            assert_snapshot!(
478                formatted_plan,
479                @ $expected,
480            );
481
482            Ok(())
483        }};
484    }
485
486    #[test]
487    fn eliminate_cross_with_simple_and() -> Result<()> {
488        let t1 = test_table_scan_with_name("t1")?;
489        let t2 = test_table_scan_with_name("t2")?;
490
491        // could eliminate to inner join since filter has Join predicates
492        let plan = LogicalPlanBuilder::from(t1)
493            .cross_join(t2)?
494            .filter(binary_expr(
495                col("t1.a").eq(col("t2.a")),
496                And,
497                col("t2.c").lt(lit(20u32)),
498            ))?
499            .build()?;
500
501        assert_optimized_plan_equal!(
502            plan,
503            @ r"
504        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
505          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
506            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
507            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
508        "
509        )
510    }
511
512    #[test]
513    fn eliminate_cross_with_simple_or() -> Result<()> {
514        let t1 = test_table_scan_with_name("t1")?;
515        let t2 = test_table_scan_with_name("t2")?;
516
517        // could not eliminate to inner join since filter OR expression and there is no common
518        // Join predicates in left and right of OR expr.
519        let plan = LogicalPlanBuilder::from(t1)
520            .cross_join(t2)?
521            .filter(binary_expr(
522                col("t1.a").eq(col("t2.a")),
523                Or,
524                col("t2.b").eq(col("t1.a")),
525            ))?
526            .build()?;
527
528        assert_optimized_plan_equal!(
529            plan,
530            @ r"
531        Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
532          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
533            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
534            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
535        "
536        )
537    }
538
539    #[test]
540    fn eliminate_cross_with_and() -> Result<()> {
541        let t1 = test_table_scan_with_name("t1")?;
542        let t2 = test_table_scan_with_name("t2")?;
543
544        // could eliminate to inner join
545        let plan = LogicalPlanBuilder::from(t1)
546            .cross_join(t2)?
547            .filter(binary_expr(
548                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
549                And,
550                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
551            ))?
552            .build()?;
553
554        assert_optimized_plan_equal!(
555            plan,
556            @ r"
557        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
558          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
559            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
560            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
561        "
562        )
563    }
564
565    #[test]
566    fn eliminate_cross_with_or() -> Result<()> {
567        let t1 = test_table_scan_with_name("t1")?;
568        let t2 = test_table_scan_with_name("t2")?;
569
570        // could eliminate to inner join since Or predicates have common Join predicates
571        let plan = LogicalPlanBuilder::from(t1)
572            .cross_join(t2)?
573            .filter(binary_expr(
574                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
575                Or,
576                binary_expr(
577                    col("t1.a").eq(col("t2.a")),
578                    And,
579                    col("t2.c").eq(lit(688u32)),
580                ),
581            ))?
582            .build()?;
583
584        assert_optimized_plan_equal!(
585            plan,
586            @ r"
587        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
588          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
589            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
590            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
591        "
592        )
593    }
594
595    #[test]
596    fn eliminate_cross_not_possible_simple() -> Result<()> {
597        let t1 = test_table_scan_with_name("t1")?;
598        let t2 = test_table_scan_with_name("t2")?;
599
600        // could not eliminate to inner join
601        let plan = LogicalPlanBuilder::from(t1)
602            .cross_join(t2)?
603            .filter(binary_expr(
604                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
605                Or,
606                binary_expr(
607                    col("t1.b").eq(col("t2.b")),
608                    And,
609                    col("t2.c").eq(lit(688u32)),
610                ),
611            ))?
612            .build()?;
613
614        assert_optimized_plan_equal!(
615            plan,
616            @ r"
617        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
618          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
619            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
620            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
621        "
622        )
623    }
624
625    #[test]
626    fn eliminate_cross_not_possible() -> Result<()> {
627        let t1 = test_table_scan_with_name("t1")?;
628        let t2 = test_table_scan_with_name("t2")?;
629
630        // could not eliminate to inner join
631        let plan = LogicalPlanBuilder::from(t1)
632            .cross_join(t2)?
633            .filter(binary_expr(
634                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
635                Or,
636                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
637            ))?
638            .build()?;
639
640        assert_optimized_plan_equal!(
641            plan,
642            @ r"
643        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
644          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
645            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
646            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
647        "
648        )
649    }
650
651    #[test]
652    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
653        let t1 = test_table_scan_with_name("t1")?;
654        let t2 = test_table_scan_with_name("t2")?;
655        let t3 = test_table_scan_with_name("t3")?;
656
657        // could not eliminate to inner join with filter
658        let plan = LogicalPlanBuilder::from(t1)
659            .join(
660                t3,
661                JoinType::Inner,
662                (vec!["t1.a"], vec!["t3.a"]),
663                Some(col("t1.a").gt(lit(20u32))),
664            )?
665            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
666            .filter(col("t1.a").gt(lit(15u32)))?
667            .build()?;
668
669        assert_optimized_plan_equal!(
670            plan,
671            @ r"
672        Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
673          Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
674            Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
675              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
676                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
677                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
678              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
679        "
680        )
681    }
682
683    #[test]
684    /// ```txt
685    /// filter: a.id = b.id and a.id = c.id
686    ///   cross_join a (bc)
687    ///     cross_join b c
688    /// ```
689    /// Without reorder, it will be
690    /// ```txt
691    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
692    ///     cross_join b c
693    /// ```
694    /// Reorder it to be
695    /// ```txt
696    ///   inner_join (ab)c and a.id = c.id
697    ///     inner_join a b on a.id = b.id
698    /// ```
699    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
700        let t1 = test_table_scan_with_name("t1")?;
701        let t2 = test_table_scan_with_name("t2")?;
702        let t3 = test_table_scan_with_name("t3")?;
703
704        // could eliminate to inner join
705        let plan = LogicalPlanBuilder::from(t1)
706            .cross_join(t2)?
707            .cross_join(t3)?
708            .filter(binary_expr(
709                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
710                And,
711                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
712            ))?
713            .build()?;
714
715        assert_optimized_plan_equal!(
716            plan,
717            @ r"
718        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
719          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
720            Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
721              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
722                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
723                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
724              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
725        "
726        )
727    }
728
729    #[test]
730    fn eliminate_cross_join_multi_tables() -> Result<()> {
731        let t1 = test_table_scan_with_name("t1")?;
732        let t2 = test_table_scan_with_name("t2")?;
733        let t3 = test_table_scan_with_name("t3")?;
734        let t4 = test_table_scan_with_name("t4")?;
735
736        // could eliminate to inner join
737        let plan1 = LogicalPlanBuilder::from(t1)
738            .cross_join(t2)?
739            .filter(binary_expr(
740                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
741                Or,
742                binary_expr(
743                    col("t1.a").eq(col("t2.a")),
744                    And,
745                    col("t2.c").eq(lit(688u32)),
746                ),
747            ))?
748            .build()?;
749
750        let plan2 = LogicalPlanBuilder::from(t3)
751            .cross_join(t4)?
752            .filter(binary_expr(
753                binary_expr(
754                    binary_expr(
755                        col("t3.a").eq(col("t4.a")),
756                        And,
757                        col("t4.c").lt(lit(15u32)),
758                    ),
759                    Or,
760                    binary_expr(
761                        col("t3.a").eq(col("t4.a")),
762                        And,
763                        col("t3.c").eq(lit(688u32)),
764                    ),
765                ),
766                Or,
767                binary_expr(
768                    col("t3.a").eq(col("t4.a")),
769                    And,
770                    col("t3.b").eq(col("t4.b")),
771                ),
772            ))?
773            .build()?;
774
775        let plan = LogicalPlanBuilder::from(plan1)
776            .cross_join(plan2)?
777            .filter(binary_expr(
778                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
779                Or,
780                binary_expr(
781                    col("t3.a").eq(col("t1.a")),
782                    And,
783                    col("t4.c").eq(lit(688u32)),
784                ),
785            ))?
786            .build()?;
787
788        assert_optimized_plan_equal!(
789            plan,
790            @ r"
791        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
792          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
793            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
794              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
795                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
796                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
797            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
798              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
799                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
800                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
801        "
802        )
803    }
804
805    #[test]
806    fn eliminate_cross_join_multi_tables_1() -> Result<()> {
807        let t1 = test_table_scan_with_name("t1")?;
808        let t2 = test_table_scan_with_name("t2")?;
809        let t3 = test_table_scan_with_name("t3")?;
810        let t4 = test_table_scan_with_name("t4")?;
811
812        // could eliminate to inner join
813        let plan1 = LogicalPlanBuilder::from(t1)
814            .cross_join(t2)?
815            .filter(binary_expr(
816                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
817                Or,
818                binary_expr(
819                    col("t1.a").eq(col("t2.a")),
820                    And,
821                    col("t2.c").eq(lit(688u32)),
822                ),
823            ))?
824            .build()?;
825
826        // could eliminate to inner join
827        let plan2 = LogicalPlanBuilder::from(t3)
828            .cross_join(t4)?
829            .filter(binary_expr(
830                binary_expr(
831                    binary_expr(
832                        col("t3.a").eq(col("t4.a")),
833                        And,
834                        col("t4.c").lt(lit(15u32)),
835                    ),
836                    Or,
837                    binary_expr(
838                        col("t3.a").eq(col("t4.a")),
839                        And,
840                        col("t3.c").eq(lit(688u32)),
841                    ),
842                ),
843                Or,
844                binary_expr(
845                    col("t3.a").eq(col("t4.a")),
846                    And,
847                    col("t3.b").eq(col("t4.b")),
848                ),
849            ))?
850            .build()?;
851
852        // could not eliminate to inner join
853        let plan = LogicalPlanBuilder::from(plan1)
854            .cross_join(plan2)?
855            .filter(binary_expr(
856                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
857                Or,
858                binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
859            ))?
860            .build()?;
861
862        assert_optimized_plan_equal!(
863            plan,
864            @ r"
865        Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
866          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
867            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
868              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
869                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
870                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
871            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
872              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
873                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
874                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
875        "
876        )
877    }
878
879    #[test]
880    fn eliminate_cross_join_multi_tables_2() -> Result<()> {
881        let t1 = test_table_scan_with_name("t1")?;
882        let t2 = test_table_scan_with_name("t2")?;
883        let t3 = test_table_scan_with_name("t3")?;
884        let t4 = test_table_scan_with_name("t4")?;
885
886        // could eliminate to inner join
887        let plan1 = LogicalPlanBuilder::from(t1)
888            .cross_join(t2)?
889            .filter(binary_expr(
890                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
891                Or,
892                binary_expr(
893                    col("t1.a").eq(col("t2.a")),
894                    And,
895                    col("t2.c").eq(lit(688u32)),
896                ),
897            ))?
898            .build()?;
899
900        // could not eliminate to inner join
901        let plan2 = LogicalPlanBuilder::from(t3)
902            .cross_join(t4)?
903            .filter(binary_expr(
904                binary_expr(
905                    binary_expr(
906                        col("t3.a").eq(col("t4.a")),
907                        And,
908                        col("t4.c").lt(lit(15u32)),
909                    ),
910                    Or,
911                    binary_expr(
912                        col("t3.a").eq(col("t4.a")),
913                        And,
914                        col("t3.c").eq(lit(688u32)),
915                    ),
916                ),
917                Or,
918                binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
919            ))?
920            .build()?;
921
922        // could eliminate to inner join
923        let plan = LogicalPlanBuilder::from(plan1)
924            .cross_join(plan2)?
925            .filter(binary_expr(
926                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
927                Or,
928                binary_expr(
929                    col("t3.a").eq(col("t1.a")),
930                    And,
931                    col("t4.c").eq(lit(688u32)),
932                ),
933            ))?
934            .build()?;
935
936        assert_optimized_plan_equal!(
937            plan,
938            @ r"
939        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
940          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
941            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
942              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
943                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
944                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
945            Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
946              Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
947                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
948                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
949        "
950        )
951    }
952
953    #[test]
954    fn eliminate_cross_join_multi_tables_3() -> Result<()> {
955        let t1 = test_table_scan_with_name("t1")?;
956        let t2 = test_table_scan_with_name("t2")?;
957        let t3 = test_table_scan_with_name("t3")?;
958        let t4 = test_table_scan_with_name("t4")?;
959
960        // could not eliminate to inner join
961        let plan1 = LogicalPlanBuilder::from(t1)
962            .cross_join(t2)?
963            .filter(binary_expr(
964                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
965                Or,
966                binary_expr(
967                    col("t1.a").eq(col("t2.a")),
968                    And,
969                    col("t2.c").eq(lit(688u32)),
970                ),
971            ))?
972            .build()?;
973
974        // could eliminate to inner join
975        let plan2 = LogicalPlanBuilder::from(t3)
976            .cross_join(t4)?
977            .filter(binary_expr(
978                binary_expr(
979                    binary_expr(
980                        col("t3.a").eq(col("t4.a")),
981                        And,
982                        col("t4.c").lt(lit(15u32)),
983                    ),
984                    Or,
985                    binary_expr(
986                        col("t3.a").eq(col("t4.a")),
987                        And,
988                        col("t3.c").eq(lit(688u32)),
989                    ),
990                ),
991                Or,
992                binary_expr(
993                    col("t3.a").eq(col("t4.a")),
994                    And,
995                    col("t3.b").eq(col("t4.b")),
996                ),
997            ))?
998            .build()?;
999
1000        // could eliminate to inner join
1001        let plan = LogicalPlanBuilder::from(plan1)
1002            .cross_join(plan2)?
1003            .filter(binary_expr(
1004                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
1005                Or,
1006                binary_expr(
1007                    col("t3.a").eq(col("t1.a")),
1008                    And,
1009                    col("t4.c").eq(lit(688u32)),
1010                ),
1011            ))?
1012            .build()?;
1013
1014        assert_optimized_plan_equal!(
1015            plan,
1016            @ r"
1017        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1018          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1019            Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1020              Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1021                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1022                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1023            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1024              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1025                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1026                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1027        "
1028        )
1029    }
1030
1031    #[test]
1032    fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1033        let t1 = test_table_scan_with_name("t1")?;
1034        let t2 = test_table_scan_with_name("t2")?;
1035        let t3 = test_table_scan_with_name("t3")?;
1036        let t4 = test_table_scan_with_name("t4")?;
1037
1038        // could eliminate to inner join
1039        // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
1040        let plan1 = LogicalPlanBuilder::from(t1)
1041            .cross_join(t2)?
1042            .filter(binary_expr(
1043                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1044                And,
1045                binary_expr(
1046                    col("t1.a").eq(col("t2.a")),
1047                    And,
1048                    col("t2.c").eq(lit(688u32)),
1049                ),
1050            ))?
1051            .build()?;
1052
1053        // could eliminate to inner join
1054        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1055
1056        // could eliminate to inner join
1057        // filter:
1058        //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1059        //     AND
1060        //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1061        let plan = LogicalPlanBuilder::from(plan1)
1062            .cross_join(plan2)?
1063            .filter(binary_expr(
1064                binary_expr(
1065                    binary_expr(
1066                        col("t3.a").eq(col("t1.a")),
1067                        And,
1068                        col("t4.c").lt(lit(15u32)),
1069                    ),
1070                    Or,
1071                    binary_expr(
1072                        col("t3.a").eq(col("t1.a")),
1073                        And,
1074                        col("t4.c").eq(lit(688u32)),
1075                    ),
1076                ),
1077                And,
1078                binary_expr(
1079                    binary_expr(
1080                        binary_expr(
1081                            col("t3.a").eq(col("t4.a")),
1082                            And,
1083                            col("t4.c").lt(lit(15u32)),
1084                        ),
1085                        Or,
1086                        binary_expr(
1087                            col("t3.a").eq(col("t4.a")),
1088                            And,
1089                            col("t3.c").eq(lit(688u32)),
1090                        ),
1091                    ),
1092                    Or,
1093                    binary_expr(
1094                        col("t3.a").eq(col("t4.a")),
1095                        And,
1096                        col("t3.b").eq(col("t4.b")),
1097                    ),
1098                ),
1099            ))?
1100            .build()?;
1101
1102        assert_optimized_plan_equal!(
1103            plan,
1104            @ r"
1105        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1106          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1107            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1108              Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1109                Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1110                  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1111                  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1112              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1113            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1114        "
1115        )
1116    }
1117
1118    #[test]
1119    fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1120        let t1 = test_table_scan_with_name("t1")?;
1121        let t2 = test_table_scan_with_name("t2")?;
1122        let t3 = test_table_scan_with_name("t3")?;
1123        let t4 = test_table_scan_with_name("t4")?;
1124
1125        // could eliminate to inner join
1126        let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1127
1128        // could eliminate to inner join
1129        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1130
1131        // could eliminate to inner join
1132        // Filter:
1133        //  ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1134        //      AND
1135        //  ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1136        //      AND
1137        //  ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
1138        let plan = LogicalPlanBuilder::from(plan1)
1139            .cross_join(plan2)?
1140            .filter(binary_expr(
1141                binary_expr(
1142                    binary_expr(
1143                        binary_expr(
1144                            col("t3.a").eq(col("t1.a")),
1145                            And,
1146                            col("t4.c").lt(lit(15u32)),
1147                        ),
1148                        Or,
1149                        binary_expr(
1150                            col("t3.a").eq(col("t1.a")),
1151                            And,
1152                            col("t4.c").eq(lit(688u32)),
1153                        ),
1154                    ),
1155                    And,
1156                    binary_expr(
1157                        binary_expr(
1158                            binary_expr(
1159                                col("t3.a").eq(col("t4.a")),
1160                                And,
1161                                col("t4.c").lt(lit(15u32)),
1162                            ),
1163                            Or,
1164                            binary_expr(
1165                                col("t3.a").eq(col("t4.a")),
1166                                And,
1167                                col("t3.c").eq(lit(688u32)),
1168                            ),
1169                        ),
1170                        Or,
1171                        binary_expr(
1172                            col("t3.a").eq(col("t4.a")),
1173                            And,
1174                            col("t3.b").eq(col("t4.b")),
1175                        ),
1176                    ),
1177                ),
1178                And,
1179                binary_expr(
1180                    binary_expr(
1181                        col("t1.a").eq(col("t2.a")),
1182                        Or,
1183                        col("t2.c").lt(lit(15u32)),
1184                    ),
1185                    And,
1186                    binary_expr(
1187                        col("t1.a").eq(col("t2.a")),
1188                        And,
1189                        col("t2.c").eq(lit(688u32)),
1190                    ),
1191                ),
1192            ))?
1193            .build()?;
1194
1195        assert_optimized_plan_equal!(
1196            plan,
1197            @ r"
1198        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1199          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1200            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1201              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1202                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1203                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1204              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1205            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1206        "
1207        )
1208    }
1209
1210    #[test]
1211    fn eliminate_cross_join_with_expr_and() -> Result<()> {
1212        let t1 = test_table_scan_with_name("t1")?;
1213        let t2 = test_table_scan_with_name("t2")?;
1214
1215        // could eliminate to inner join since filter has Join predicates
1216        let plan = LogicalPlanBuilder::from(t1)
1217            .cross_join(t2)?
1218            .filter(binary_expr(
1219                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1220                And,
1221                col("t2.c").lt(lit(20u32)),
1222            ))?
1223            .build()?;
1224
1225        assert_optimized_plan_equal!(
1226            plan,
1227            @ r"
1228        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1229          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1230            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1231            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1232        "
1233        )
1234    }
1235
1236    #[test]
1237    fn eliminate_cross_with_expr_or() -> Result<()> {
1238        let t1 = test_table_scan_with_name("t1")?;
1239        let t2 = test_table_scan_with_name("t2")?;
1240
1241        // could not eliminate to inner join since filter OR expression and there is no common
1242        // Join predicates in left and right of OR expr.
1243        let plan = LogicalPlanBuilder::from(t1)
1244            .cross_join(t2)?
1245            .filter(binary_expr(
1246                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1247                Or,
1248                col("t2.b").eq(col("t1.a")),
1249            ))?
1250            .build()?;
1251
1252        assert_optimized_plan_equal!(
1253            plan,
1254            @ r"
1255        Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1256          Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1257            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1258            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1259        "
1260        )
1261    }
1262
1263    #[test]
1264    fn eliminate_cross_with_common_expr_and() -> Result<()> {
1265        let t1 = test_table_scan_with_name("t1")?;
1266        let t2 = test_table_scan_with_name("t2")?;
1267
1268        // could eliminate to inner join
1269        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1270        let plan = LogicalPlanBuilder::from(t1)
1271            .cross_join(t2)?
1272            .filter(binary_expr(
1273                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
1274                And,
1275                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
1276            ))?
1277            .build()?;
1278
1279        assert_optimized_plan_equal!(
1280            plan,
1281            @ r"
1282        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1283          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1284            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1285            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1286        "
1287        )
1288    }
1289
1290    #[test]
1291    fn eliminate_cross_with_common_expr_or() -> Result<()> {
1292        let t1 = test_table_scan_with_name("t1")?;
1293        let t2 = test_table_scan_with_name("t2")?;
1294
1295        // could eliminate to inner join since Or predicates have common Join predicates
1296        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1297        let plan = LogicalPlanBuilder::from(t1)
1298            .cross_join(t2)?
1299            .filter(binary_expr(
1300                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
1301                Or,
1302                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
1303            ))?
1304            .build()?;
1305
1306        assert_optimized_plan_equal!(
1307            plan,
1308            @ r"
1309        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1310          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1311            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1312            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1313        "
1314        )
1315    }
1316
1317    #[test]
1318    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
1319        let t1 = test_table_scan_with_name("t1")?;
1320        let t2 = test_table_scan_with_name("t2")?;
1321        let t3 = test_table_scan_with_name("t3")?;
1322
1323        // could eliminate to inner join
1324        let plan = LogicalPlanBuilder::from(t1)
1325            .cross_join(t2)?
1326            .cross_join(t3)?
1327            .filter(binary_expr(
1328                binary_expr(
1329                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
1330                    And,
1331                    col("t3.c").lt(lit(15u32)),
1332                ),
1333                And,
1334                binary_expr(
1335                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1336                    And,
1337                    col("t3.b").lt(lit(15u32)),
1338                ),
1339            ))?
1340            .build()?;
1341
1342        assert_optimized_plan_equal!(
1343            plan,
1344            @ r"
1345        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1346          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1347            Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1348              Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1349                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1350                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1351              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1352        "
1353        )
1354    }
1355
1356    #[test]
1357    fn preserve_null_equality_setting() -> Result<()> {
1358        let t1 = test_table_scan_with_name("t1")?;
1359        let t2 = test_table_scan_with_name("t2")?;
1360
1361        // Create an inner join with NullEquality::NullEqualsNull
1362        let join_schema = Arc::new(build_join_schema(
1363            t1.schema(),
1364            t2.schema(),
1365            &JoinType::Inner,
1366        )?);
1367
1368        let inner_join = LogicalPlan::Join(Join {
1369            left: Arc::new(t1),
1370            right: Arc::new(t2),
1371            join_type: JoinType::Inner,
1372            join_constraint: JoinConstraint::On,
1373            on: vec![],
1374            filter: None,
1375            schema: join_schema,
1376            null_equality: NullEquality::NullEqualsNull, // Test preservation
1377            null_aware: false,
1378        });
1379
1380        // Apply filter that can create join conditions
1381        let plan = LogicalPlanBuilder::from(inner_join)
1382            .filter(binary_expr(
1383                col("t1.a").eq(col("t2.a")),
1384                And,
1385                col("t2.c").lt(lit(20u32)),
1386            ))?
1387            .build()?;
1388
1389        let rule = EliminateCrossJoin::new();
1390        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
1391
1392        // Verify that null_equality is preserved in the optimized plan
1393        fn check_null_equality_preserved(plan: &LogicalPlan) -> bool {
1394            match plan {
1395                LogicalPlan::Join(join) => {
1396                    // All joins in the optimized plan should preserve null equality
1397                    if join.null_equality == NullEquality::NullEqualsNothing {
1398                        return false;
1399                    }
1400                    // Recursively check child plans
1401                    plan.inputs()
1402                        .iter()
1403                        .all(|input| check_null_equality_preserved(input))
1404                }
1405                _ => {
1406                    // Recursively check child plans for non-join nodes
1407                    plan.inputs()
1408                        .iter()
1409                        .all(|input| check_null_equality_preserved(input))
1410                }
1411            }
1412        }
1413
1414        assert!(
1415            check_null_equality_preserved(&optimized_plan),
1416            "null_equality setting should be preserved after optimization"
1417        );
1418
1419        Ok(())
1420    }
1421}