datafusion_optimizer/
eliminate_cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27    Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema};
31
/// Optimizer rule that converts `CROSS JOIN` into `INNER JOIN` when join
/// predicates are available.
///
/// The rule is a stateless unit type; the rewrite itself lives in its
/// `OptimizerRule` implementation.
#[derive(Debug, Default)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    // Plain constructor for the stateless rule. Deliberately a regular comment:
    // the `#[expect(missing_docs)]` below would be unfulfilled by a doc comment.
    #[expect(missing_docs)]
    pub fn new() -> Self {
        Self
    }
}
41
42/// Eliminate cross joins by rewriting them to inner joins when possible.
43///
44/// # Example
45/// The initial plan for this query:
46/// ```sql
47/// select ... from a, b where a.x = b.y and b.xx = 100;
48/// ```
49///
50/// Looks like this:
51/// ```text
52/// Filter(a.x = b.y AND b.xx = 100)
53///  Cross Join
54///   TableScan a
55///   TableScan b
56/// ```
57///
58/// After the rule is applied, the plan will look like this:
59/// ```text
60/// Filter(b.xx = 100)
61///   InnerJoin(a.x = b.y)
62///     TableScan a
63///     TableScan b
64/// ```
65///
66/// # Other Examples
67/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
68/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
69/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
70/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);'
71/// * 'select ... from a, b where a.x > b.y'
72///
73/// For above queries, the join predicate is available in filters and they are moved to
74/// join nodes appropriately
75///
76/// This fix helps to improve the performance of TPCH Q19. issue#78
77impl OptimizerRule for EliminateCrossJoin {
78    fn supports_rewrite(&self) -> bool {
79        true
80    }
81
82    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        let plan_schema = Arc::clone(plan.schema());
89        let mut possible_join_keys = JoinKeySet::new();
90        let mut all_inputs: Vec<LogicalPlan> = vec![];
91        let mut all_filters: Vec<Expr> = vec![];
92        let mut null_equality = NullEquality::NullEqualsNothing;
93
94        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
95            // if input isn't a join that can potentially be rewritten
96            // avoid unwrapping the input
97            let rewritable = matches!(
98                filter.input.as_ref(),
99                LogicalPlan::Join(Join {
100                    join_type: JoinType::Inner,
101                    ..
102                })
103            );
104
105            if !rewritable {
106                // recursively try to rewrite children
107                return rewrite_children(self, LogicalPlan::Filter(filter), config);
108            }
109
110            if !can_flatten_join_inputs(&filter.input) {
111                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
112            }
113
114            let Filter {
115                input, predicate, ..
116            } = filter;
117
118            // Extract null_equality setting from the input join
119            if let LogicalPlan::Join(join) = input.as_ref() {
120                null_equality = join.null_equality;
121            }
122
123            flatten_join_inputs(
124                Arc::unwrap_or_clone(input),
125                &mut possible_join_keys,
126                &mut all_inputs,
127                &mut all_filters,
128            )?;
129
130            extract_possible_join_keys(&predicate, &mut possible_join_keys);
131            Some(predicate)
132        } else {
133            match plan {
134                LogicalPlan::Join(Join {
135                    join_type: JoinType::Inner,
136                    null_equality: original_null_equality,
137                    ..
138                }) => {
139                    if !can_flatten_join_inputs(&plan) {
140                        return Ok(Transformed::no(plan));
141                    }
142                    flatten_join_inputs(
143                        plan,
144                        &mut possible_join_keys,
145                        &mut all_inputs,
146                        &mut all_filters,
147                    )?;
148                    null_equality = original_null_equality;
149                    None
150                }
151                _ => {
152                    // recursively try to rewrite children
153                    return rewrite_children(self, plan, config);
154                }
155            }
156        };
157
158        // Join keys are handled locally:
159        let mut all_join_keys = JoinKeySet::new();
160        let mut left = all_inputs.remove(0);
161        while !all_inputs.is_empty() {
162            left = find_inner_join(
163                left,
164                &mut all_inputs,
165                &possible_join_keys,
166                &mut all_join_keys,
167                null_equality,
168            )?;
169        }
170
171        left = rewrite_children(self, left, config)?.data;
172
173        if &plan_schema != left.schema() {
174            left = LogicalPlan::Projection(Projection::new_from_schema(
175                Arc::new(left),
176                Arc::clone(&plan_schema),
177            ));
178        }
179
180        if !all_filters.is_empty() {
181            // Add any filters on top - PushDownFilter can push filters down to applicable join
182            let first = all_filters.swap_remove(0);
183            let predicate = all_filters.into_iter().fold(first, and);
184            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
185        }
186
187        let Some(predicate) = parent_predicate else {
188            return Ok(Transformed::yes(left));
189        };
190
191        // If there are no join keys then do nothing:
192        if all_join_keys.is_empty() {
193            Filter::try_new(predicate, Arc::new(left))
194                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
195        } else {
196            // Remove join expressions from filter:
197            match remove_join_expressions(predicate, &all_join_keys) {
198                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
199                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
200                _ => Ok(Transformed::yes(left)),
201            }
202        }
203    }
204
205    fn name(&self) -> &str {
206        "eliminate_cross_join"
207    }
208}
209
210fn rewrite_children(
211    optimizer: &impl OptimizerRule,
212    plan: LogicalPlan,
213    config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215    let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
216
217    // recompute schema if the plan was transformed
218    if transformed_plan.transformed {
219        transformed_plan.map_data(|plan| plan.recompute_schema())
220    } else {
221        Ok(transformed_plan)
222    }
223}
224
225/// Recursively accumulate possible_join_keys and inputs from inner joins
226/// (including cross joins).
227///
228/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
229/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
230/// possible_join_keys
231fn flatten_join_inputs(
232    plan: LogicalPlan,
233    possible_join_keys: &mut JoinKeySet,
234    all_inputs: &mut Vec<LogicalPlan>,
235    all_filters: &mut Vec<Expr>,
236) -> Result<()> {
237    match plan {
238        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
239            if let Some(filter) = join.filter {
240                all_filters.push(filter);
241            }
242            possible_join_keys.insert_all_owned(join.on);
243            flatten_join_inputs(
244                Arc::unwrap_or_clone(join.left),
245                possible_join_keys,
246                all_inputs,
247                all_filters,
248            )?;
249            flatten_join_inputs(
250                Arc::unwrap_or_clone(join.right),
251                possible_join_keys,
252                all_inputs,
253                all_filters,
254            )?;
255        }
256        _ => {
257            all_inputs.push(plan);
258        }
259    };
260    Ok(())
261}
262
263/// Returns true if the plan is a Join or Cross join could be flattened with
264/// `flatten_join_inputs`
265///
266/// Must stay in sync with `flatten_join_inputs`
267fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
268    // can only flatten inner / cross joins
269    match plan {
270        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
271        _ => return false,
272    };
273
274    for child in plan.inputs() {
275        if let LogicalPlan::Join(Join {
276            join_type: JoinType::Inner,
277            ..
278        }) = child
279            && !can_flatten_join_inputs(child)
280        {
281            return false;
282        }
283    }
284    true
285}
286
287/// Finds the next to join with the left input plan,
288///
289/// Finds the next `right` from `rights` that can be joined with `left_input`
290/// plan based on the join keys in `possible_join_keys`.
291///
292/// If such a matching `right` is found:
293/// 1. Adds the matching join keys to `all_join_keys`.
294/// 2. Returns `left_input JOIN right ON (all join keys)`.
295///
296/// If no matching `right` is found:
297/// 1. Removes the first plan from `rights`
298/// 2. Returns `left_input CROSS JOIN right`.
299fn find_inner_join(
300    left_input: LogicalPlan,
301    rights: &mut Vec<LogicalPlan>,
302    possible_join_keys: &JoinKeySet,
303    all_join_keys: &mut JoinKeySet,
304    null_equality: NullEquality,
305) -> Result<LogicalPlan> {
306    for (i, right_input) in rights.iter().enumerate() {
307        let mut join_keys = vec![];
308
309        for (l, r) in possible_join_keys.iter() {
310            let key_pair = find_valid_equijoin_key_pair(
311                l,
312                r,
313                left_input.schema(),
314                right_input.schema(),
315            )?;
316
317            // Save join keys
318            if let Some((valid_l, valid_r)) = key_pair
319                && can_hash(&valid_l.get_type(left_input.schema())?)
320            {
321                join_keys.push((valid_l, valid_r));
322            }
323        }
324
325        // Found one or more matching join keys
326        if !join_keys.is_empty() {
327            all_join_keys.insert_all(join_keys.iter());
328            let right_input = rights.remove(i);
329            let join_schema = Arc::new(build_join_schema(
330                left_input.schema(),
331                right_input.schema(),
332                &JoinType::Inner,
333            )?);
334
335            return Ok(LogicalPlan::Join(Join {
336                left: Arc::new(left_input),
337                right: Arc::new(right_input),
338                join_type: JoinType::Inner,
339                join_constraint: JoinConstraint::On,
340                on: join_keys,
341                filter: None,
342                schema: join_schema,
343                null_equality,
344            }));
345        }
346    }
347
348    // no matching right plan had any join keys, cross join with the first right
349    // plan
350    let right = rights.remove(0);
351    let join_schema = Arc::new(build_join_schema(
352        left_input.schema(),
353        right.schema(),
354        &JoinType::Inner,
355    )?);
356
357    Ok(LogicalPlan::Join(Join {
358        left: Arc::new(left_input),
359        right: Arc::new(right),
360        schema: join_schema,
361        on: vec![],
362        filter: None,
363        join_type: JoinType::Inner,
364        join_constraint: JoinConstraint::On,
365        null_equality,
366    }))
367}
368
369/// Extract join keys from a WHERE clause
370fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
371    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
372        match op {
373            Operator::Eq => {
374                // insert handles ensuring  we don't add the same Join keys multiple times
375                join_keys.insert(left, right);
376            }
377            Operator::And => {
378                extract_possible_join_keys(left, join_keys);
379                extract_possible_join_keys(right, join_keys)
380            }
381            // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
382            Operator::Or => {
383                let mut left_join_keys = JoinKeySet::new();
384                let mut right_join_keys = JoinKeySet::new();
385
386                extract_possible_join_keys(left, &mut left_join_keys);
387                extract_possible_join_keys(right, &mut right_join_keys);
388
389                join_keys.insert_intersection(&left_join_keys, &right_join_keys)
390            }
391            _ => (),
392        };
393    }
394}
395
396/// Remove join expressions from a filter expression
397///
398/// # Returns
399/// * `Some()` when there are few remaining predicates in filter_expr
400/// * `None` otherwise
401fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
402    match expr {
403        Expr::BinaryExpr(BinaryExpr {
404            left,
405            op: Operator::Eq,
406            right,
407        }) if join_keys.contains(&left, &right) => {
408            // was a join key, so remove it
409            None
410        }
411        // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
412        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
413            let l = remove_join_expressions(*left, join_keys);
414            let r = remove_join_expressions(*right, join_keys);
415            match (l, r) {
416                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
417                    Box::new(ll),
418                    op,
419                    Box::new(rr),
420                ))),
421                (Some(ll), _) => Some(ll),
422                (_, Some(rr)) => Some(rr),
423                _ => None,
424            }
425        }
426        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
427            let l = remove_join_expressions(*left, join_keys);
428            let r = remove_join_expressions(*right, join_keys);
429            match (l, r) {
430                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
431                    Box::new(ll),
432                    op,
433                    Box::new(rr),
434                ))),
435                // When either `left` or `right` is empty, it means they are `true`
436                // so OR'ing anything with them will also be true
437                _ => None,
438            }
439        }
440        _ => Some(expr),
441    }
442}
443
444#[cfg(test)]
445mod tests {
446    use super::*;
447    use crate::optimizer::OptimizerContext;
448    use crate::test::*;
449
450    use datafusion_expr::{
451        Operator::{And, Or},
452        binary_expr, col, lit,
453        logical_plan::builder::LogicalPlanBuilder,
454    };
455    use insta::assert_snapshot;
456
457    macro_rules! assert_optimized_plan_equal {
458        (
459            $plan:expr,
460            @ $expected:literal $(,)?
461        ) => {{
462            let starting_schema = Arc::clone($plan.schema());
463            let rule = EliminateCrossJoin::new();
464            let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
465            let formatted_plan = optimized_plan.display_indent_schema();
466            // Ensure the rule was actually applied
467            assert!(is_plan_transformed, "failed to optimize plan");
468            // Verify the schema remains unchanged
469            assert_eq!(&starting_schema, optimized_plan.schema());
470            assert_snapshot!(
471                formatted_plan,
472                @ $expected,
473            );
474
475            Ok(())
476        }};
477    }
478
479    #[test]
480    fn eliminate_cross_with_simple_and() -> Result<()> {
481        let t1 = test_table_scan_with_name("t1")?;
482        let t2 = test_table_scan_with_name("t2")?;
483
484        // could eliminate to inner join since filter has Join predicates
485        let plan = LogicalPlanBuilder::from(t1)
486            .cross_join(t2)?
487            .filter(binary_expr(
488                col("t1.a").eq(col("t2.a")),
489                And,
490                col("t2.c").lt(lit(20u32)),
491            ))?
492            .build()?;
493
494        assert_optimized_plan_equal!(
495            plan,
496            @ r"
497        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
498          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
499            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
500            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
501        "
502        )
503    }
504
505    #[test]
506    fn eliminate_cross_with_simple_or() -> Result<()> {
507        let t1 = test_table_scan_with_name("t1")?;
508        let t2 = test_table_scan_with_name("t2")?;
509
510        // could not eliminate to inner join since filter OR expression and there is no common
511        // Join predicates in left and right of OR expr.
512        let plan = LogicalPlanBuilder::from(t1)
513            .cross_join(t2)?
514            .filter(binary_expr(
515                col("t1.a").eq(col("t2.a")),
516                Or,
517                col("t2.b").eq(col("t1.a")),
518            ))?
519            .build()?;
520
521        assert_optimized_plan_equal!(
522            plan,
523            @ r"
524        Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
525          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
526            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
527            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
528        "
529        )
530    }
531
532    #[test]
533    fn eliminate_cross_with_and() -> Result<()> {
534        let t1 = test_table_scan_with_name("t1")?;
535        let t2 = test_table_scan_with_name("t2")?;
536
537        // could eliminate to inner join
538        let plan = LogicalPlanBuilder::from(t1)
539            .cross_join(t2)?
540            .filter(binary_expr(
541                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
542                And,
543                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
544            ))?
545            .build()?;
546
547        assert_optimized_plan_equal!(
548            plan,
549            @ r"
550        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
551          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
552            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
553            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
554        "
555        )
556    }
557
558    #[test]
559    fn eliminate_cross_with_or() -> Result<()> {
560        let t1 = test_table_scan_with_name("t1")?;
561        let t2 = test_table_scan_with_name("t2")?;
562
563        // could eliminate to inner join since Or predicates have common Join predicates
564        let plan = LogicalPlanBuilder::from(t1)
565            .cross_join(t2)?
566            .filter(binary_expr(
567                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
568                Or,
569                binary_expr(
570                    col("t1.a").eq(col("t2.a")),
571                    And,
572                    col("t2.c").eq(lit(688u32)),
573                ),
574            ))?
575            .build()?;
576
577        assert_optimized_plan_equal!(
578            plan,
579            @ r"
580        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
581          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
582            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
583            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
584        "
585        )
586    }
587
588    #[test]
589    fn eliminate_cross_not_possible_simple() -> Result<()> {
590        let t1 = test_table_scan_with_name("t1")?;
591        let t2 = test_table_scan_with_name("t2")?;
592
593        // could not eliminate to inner join
594        let plan = LogicalPlanBuilder::from(t1)
595            .cross_join(t2)?
596            .filter(binary_expr(
597                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
598                Or,
599                binary_expr(
600                    col("t1.b").eq(col("t2.b")),
601                    And,
602                    col("t2.c").eq(lit(688u32)),
603                ),
604            ))?
605            .build()?;
606
607        assert_optimized_plan_equal!(
608            plan,
609            @ r"
610        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
611          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
612            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
613            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
614        "
615        )
616    }
617
618    #[test]
619    fn eliminate_cross_not_possible() -> Result<()> {
620        let t1 = test_table_scan_with_name("t1")?;
621        let t2 = test_table_scan_with_name("t2")?;
622
623        // could not eliminate to inner join
624        let plan = LogicalPlanBuilder::from(t1)
625            .cross_join(t2)?
626            .filter(binary_expr(
627                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
628                Or,
629                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
630            ))?
631            .build()?;
632
633        assert_optimized_plan_equal!(
634            plan,
635            @ r"
636        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
637          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
638            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
639            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
640        "
641        )
642    }
643
644    #[test]
645    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
646        let t1 = test_table_scan_with_name("t1")?;
647        let t2 = test_table_scan_with_name("t2")?;
648        let t3 = test_table_scan_with_name("t3")?;
649
650        // could not eliminate to inner join with filter
651        let plan = LogicalPlanBuilder::from(t1)
652            .join(
653                t3,
654                JoinType::Inner,
655                (vec!["t1.a"], vec!["t3.a"]),
656                Some(col("t1.a").gt(lit(20u32))),
657            )?
658            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
659            .filter(col("t1.a").gt(lit(15u32)))?
660            .build()?;
661
662        assert_optimized_plan_equal!(
663            plan,
664            @ r"
665        Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
666          Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
667            Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
668              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
669                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
670                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
671              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
672        "
673        )
674    }
675
676    #[test]
677    /// ```txt
678    /// filter: a.id = b.id and a.id = c.id
679    ///   cross_join a (bc)
680    ///     cross_join b c
681    /// ```
682    /// Without reorder, it will be
683    /// ```txt
684    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
685    ///     cross_join b c
686    /// ```
687    /// Reorder it to be
688    /// ```txt
689    ///   inner_join (ab)c and a.id = c.id
690    ///     inner_join a b on a.id = b.id
691    /// ```
692    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
693        let t1 = test_table_scan_with_name("t1")?;
694        let t2 = test_table_scan_with_name("t2")?;
695        let t3 = test_table_scan_with_name("t3")?;
696
697        // could eliminate to inner join
698        let plan = LogicalPlanBuilder::from(t1)
699            .cross_join(t2)?
700            .cross_join(t3)?
701            .filter(binary_expr(
702                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
703                And,
704                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
705            ))?
706            .build()?;
707
708        assert_optimized_plan_equal!(
709            plan,
710            @ r"
711        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
712          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
713            Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
714              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
715                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
716                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
717              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
718        "
719        )
720    }
721
722    #[test]
723    fn eliminate_cross_join_multi_tables() -> Result<()> {
724        let t1 = test_table_scan_with_name("t1")?;
725        let t2 = test_table_scan_with_name("t2")?;
726        let t3 = test_table_scan_with_name("t3")?;
727        let t4 = test_table_scan_with_name("t4")?;
728
729        // could eliminate to inner join
730        let plan1 = LogicalPlanBuilder::from(t1)
731            .cross_join(t2)?
732            .filter(binary_expr(
733                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
734                Or,
735                binary_expr(
736                    col("t1.a").eq(col("t2.a")),
737                    And,
738                    col("t2.c").eq(lit(688u32)),
739                ),
740            ))?
741            .build()?;
742
743        let plan2 = LogicalPlanBuilder::from(t3)
744            .cross_join(t4)?
745            .filter(binary_expr(
746                binary_expr(
747                    binary_expr(
748                        col("t3.a").eq(col("t4.a")),
749                        And,
750                        col("t4.c").lt(lit(15u32)),
751                    ),
752                    Or,
753                    binary_expr(
754                        col("t3.a").eq(col("t4.a")),
755                        And,
756                        col("t3.c").eq(lit(688u32)),
757                    ),
758                ),
759                Or,
760                binary_expr(
761                    col("t3.a").eq(col("t4.a")),
762                    And,
763                    col("t3.b").eq(col("t4.b")),
764                ),
765            ))?
766            .build()?;
767
768        let plan = LogicalPlanBuilder::from(plan1)
769            .cross_join(plan2)?
770            .filter(binary_expr(
771                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
772                Or,
773                binary_expr(
774                    col("t3.a").eq(col("t1.a")),
775                    And,
776                    col("t4.c").eq(lit(688u32)),
777                ),
778            ))?
779            .build()?;
780
781        assert_optimized_plan_equal!(
782            plan,
783            @ r"
784        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
785          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
786            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
787              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
788                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
789                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
790            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
791              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
792                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
793                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
794        "
795        )
796    }
797
798    #[test]
799    fn eliminate_cross_join_multi_tables_1() -> Result<()> {
800        let t1 = test_table_scan_with_name("t1")?;
801        let t2 = test_table_scan_with_name("t2")?;
802        let t3 = test_table_scan_with_name("t3")?;
803        let t4 = test_table_scan_with_name("t4")?;
804
805        // could eliminate to inner join
806        let plan1 = LogicalPlanBuilder::from(t1)
807            .cross_join(t2)?
808            .filter(binary_expr(
809                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
810                Or,
811                binary_expr(
812                    col("t1.a").eq(col("t2.a")),
813                    And,
814                    col("t2.c").eq(lit(688u32)),
815                ),
816            ))?
817            .build()?;
818
819        // could eliminate to inner join
820        let plan2 = LogicalPlanBuilder::from(t3)
821            .cross_join(t4)?
822            .filter(binary_expr(
823                binary_expr(
824                    binary_expr(
825                        col("t3.a").eq(col("t4.a")),
826                        And,
827                        col("t4.c").lt(lit(15u32)),
828                    ),
829                    Or,
830                    binary_expr(
831                        col("t3.a").eq(col("t4.a")),
832                        And,
833                        col("t3.c").eq(lit(688u32)),
834                    ),
835                ),
836                Or,
837                binary_expr(
838                    col("t3.a").eq(col("t4.a")),
839                    And,
840                    col("t3.b").eq(col("t4.b")),
841                ),
842            ))?
843            .build()?;
844
845        // could not eliminate to inner join
846        let plan = LogicalPlanBuilder::from(plan1)
847            .cross_join(plan2)?
848            .filter(binary_expr(
849                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
850                Or,
851                binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
852            ))?
853            .build()?;
854
855        assert_optimized_plan_equal!(
856            plan,
857            @ r"
858        Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
859          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
860            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
861              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
862                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
863                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
864            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
865              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
866                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
867                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
868        "
869        )
870    }
871
872    #[test]
873    fn eliminate_cross_join_multi_tables_2() -> Result<()> {
874        let t1 = test_table_scan_with_name("t1")?;
875        let t2 = test_table_scan_with_name("t2")?;
876        let t3 = test_table_scan_with_name("t3")?;
877        let t4 = test_table_scan_with_name("t4")?;
878
879        // could eliminate to inner join
880        let plan1 = LogicalPlanBuilder::from(t1)
881            .cross_join(t2)?
882            .filter(binary_expr(
883                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
884                Or,
885                binary_expr(
886                    col("t1.a").eq(col("t2.a")),
887                    And,
888                    col("t2.c").eq(lit(688u32)),
889                ),
890            ))?
891            .build()?;
892
893        // could not eliminate to inner join
894        let plan2 = LogicalPlanBuilder::from(t3)
895            .cross_join(t4)?
896            .filter(binary_expr(
897                binary_expr(
898                    binary_expr(
899                        col("t3.a").eq(col("t4.a")),
900                        And,
901                        col("t4.c").lt(lit(15u32)),
902                    ),
903                    Or,
904                    binary_expr(
905                        col("t3.a").eq(col("t4.a")),
906                        And,
907                        col("t3.c").eq(lit(688u32)),
908                    ),
909                ),
910                Or,
911                binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
912            ))?
913            .build()?;
914
915        // could eliminate to inner join
916        let plan = LogicalPlanBuilder::from(plan1)
917            .cross_join(plan2)?
918            .filter(binary_expr(
919                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
920                Or,
921                binary_expr(
922                    col("t3.a").eq(col("t1.a")),
923                    And,
924                    col("t4.c").eq(lit(688u32)),
925                ),
926            ))?
927            .build()?;
928
929        assert_optimized_plan_equal!(
930            plan,
931            @ r"
932        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
933          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
934            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
935              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
936                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
937                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
938            Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
939              Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
940                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
941                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
942        "
943        )
944    }
945
946    #[test]
947    fn eliminate_cross_join_multi_tables_3() -> Result<()> {
948        let t1 = test_table_scan_with_name("t1")?;
949        let t2 = test_table_scan_with_name("t2")?;
950        let t3 = test_table_scan_with_name("t3")?;
951        let t4 = test_table_scan_with_name("t4")?;
952
953        // could not eliminate to inner join
954        let plan1 = LogicalPlanBuilder::from(t1)
955            .cross_join(t2)?
956            .filter(binary_expr(
957                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
958                Or,
959                binary_expr(
960                    col("t1.a").eq(col("t2.a")),
961                    And,
962                    col("t2.c").eq(lit(688u32)),
963                ),
964            ))?
965            .build()?;
966
967        // could eliminate to inner join
968        let plan2 = LogicalPlanBuilder::from(t3)
969            .cross_join(t4)?
970            .filter(binary_expr(
971                binary_expr(
972                    binary_expr(
973                        col("t3.a").eq(col("t4.a")),
974                        And,
975                        col("t4.c").lt(lit(15u32)),
976                    ),
977                    Or,
978                    binary_expr(
979                        col("t3.a").eq(col("t4.a")),
980                        And,
981                        col("t3.c").eq(lit(688u32)),
982                    ),
983                ),
984                Or,
985                binary_expr(
986                    col("t3.a").eq(col("t4.a")),
987                    And,
988                    col("t3.b").eq(col("t4.b")),
989                ),
990            ))?
991            .build()?;
992
993        // could eliminate to inner join
994        let plan = LogicalPlanBuilder::from(plan1)
995            .cross_join(plan2)?
996            .filter(binary_expr(
997                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
998                Or,
999                binary_expr(
1000                    col("t3.a").eq(col("t1.a")),
1001                    And,
1002                    col("t4.c").eq(lit(688u32)),
1003                ),
1004            ))?
1005            .build()?;
1006
1007        assert_optimized_plan_equal!(
1008            plan,
1009            @ r"
1010        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1011          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1012            Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1013              Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1014                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1015                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1016            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1017              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1018                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1019                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1020        "
1021        )
1022    }
1023
1024    #[test]
1025    fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1026        let t1 = test_table_scan_with_name("t1")?;
1027        let t2 = test_table_scan_with_name("t2")?;
1028        let t3 = test_table_scan_with_name("t3")?;
1029        let t4 = test_table_scan_with_name("t4")?;
1030
1031        // could eliminate to inner join
1032        // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
1033        let plan1 = LogicalPlanBuilder::from(t1)
1034            .cross_join(t2)?
1035            .filter(binary_expr(
1036                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1037                And,
1038                binary_expr(
1039                    col("t1.a").eq(col("t2.a")),
1040                    And,
1041                    col("t2.c").eq(lit(688u32)),
1042                ),
1043            ))?
1044            .build()?;
1045
1046        // could eliminate to inner join
1047        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1048
1049        // could eliminate to inner join
1050        // filter:
1051        //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1052        //     AND
1053        //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1054        let plan = LogicalPlanBuilder::from(plan1)
1055            .cross_join(plan2)?
1056            .filter(binary_expr(
1057                binary_expr(
1058                    binary_expr(
1059                        col("t3.a").eq(col("t1.a")),
1060                        And,
1061                        col("t4.c").lt(lit(15u32)),
1062                    ),
1063                    Or,
1064                    binary_expr(
1065                        col("t3.a").eq(col("t1.a")),
1066                        And,
1067                        col("t4.c").eq(lit(688u32)),
1068                    ),
1069                ),
1070                And,
1071                binary_expr(
1072                    binary_expr(
1073                        binary_expr(
1074                            col("t3.a").eq(col("t4.a")),
1075                            And,
1076                            col("t4.c").lt(lit(15u32)),
1077                        ),
1078                        Or,
1079                        binary_expr(
1080                            col("t3.a").eq(col("t4.a")),
1081                            And,
1082                            col("t3.c").eq(lit(688u32)),
1083                        ),
1084                    ),
1085                    Or,
1086                    binary_expr(
1087                        col("t3.a").eq(col("t4.a")),
1088                        And,
1089                        col("t3.b").eq(col("t4.b")),
1090                    ),
1091                ),
1092            ))?
1093            .build()?;
1094
1095        assert_optimized_plan_equal!(
1096            plan,
1097            @ r"
1098        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1099          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1100            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1101              Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1102                Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1103                  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1104                  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1105              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1106            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1107        "
1108        )
1109    }
1110
1111    #[test]
1112    fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1113        let t1 = test_table_scan_with_name("t1")?;
1114        let t2 = test_table_scan_with_name("t2")?;
1115        let t3 = test_table_scan_with_name("t3")?;
1116        let t4 = test_table_scan_with_name("t4")?;
1117
1118        // could eliminate to inner join
1119        let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1120
1121        // could eliminate to inner join
1122        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1123
1124        // could eliminate to inner join
1125        // Filter:
1126        //  ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1127        //      AND
1128        //  ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1129        //      AND
1130        //  ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
1131        let plan = LogicalPlanBuilder::from(plan1)
1132            .cross_join(plan2)?
1133            .filter(binary_expr(
1134                binary_expr(
1135                    binary_expr(
1136                        binary_expr(
1137                            col("t3.a").eq(col("t1.a")),
1138                            And,
1139                            col("t4.c").lt(lit(15u32)),
1140                        ),
1141                        Or,
1142                        binary_expr(
1143                            col("t3.a").eq(col("t1.a")),
1144                            And,
1145                            col("t4.c").eq(lit(688u32)),
1146                        ),
1147                    ),
1148                    And,
1149                    binary_expr(
1150                        binary_expr(
1151                            binary_expr(
1152                                col("t3.a").eq(col("t4.a")),
1153                                And,
1154                                col("t4.c").lt(lit(15u32)),
1155                            ),
1156                            Or,
1157                            binary_expr(
1158                                col("t3.a").eq(col("t4.a")),
1159                                And,
1160                                col("t3.c").eq(lit(688u32)),
1161                            ),
1162                        ),
1163                        Or,
1164                        binary_expr(
1165                            col("t3.a").eq(col("t4.a")),
1166                            And,
1167                            col("t3.b").eq(col("t4.b")),
1168                        ),
1169                    ),
1170                ),
1171                And,
1172                binary_expr(
1173                    binary_expr(
1174                        col("t1.a").eq(col("t2.a")),
1175                        Or,
1176                        col("t2.c").lt(lit(15u32)),
1177                    ),
1178                    And,
1179                    binary_expr(
1180                        col("t1.a").eq(col("t2.a")),
1181                        And,
1182                        col("t2.c").eq(lit(688u32)),
1183                    ),
1184                ),
1185            ))?
1186            .build()?;
1187
1188        assert_optimized_plan_equal!(
1189            plan,
1190            @ r"
1191        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1192          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1193            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1194              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1195                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1196                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1197              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1198            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1199        "
1200        )
1201    }
1202
1203    #[test]
1204    fn eliminate_cross_join_with_expr_and() -> Result<()> {
1205        let t1 = test_table_scan_with_name("t1")?;
1206        let t2 = test_table_scan_with_name("t2")?;
1207
1208        // could eliminate to inner join since filter has Join predicates
1209        let plan = LogicalPlanBuilder::from(t1)
1210            .cross_join(t2)?
1211            .filter(binary_expr(
1212                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1213                And,
1214                col("t2.c").lt(lit(20u32)),
1215            ))?
1216            .build()?;
1217
1218        assert_optimized_plan_equal!(
1219            plan,
1220            @ r"
1221        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1222          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1223            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1224            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1225        "
1226        )
1227    }
1228
1229    #[test]
1230    fn eliminate_cross_with_expr_or() -> Result<()> {
1231        let t1 = test_table_scan_with_name("t1")?;
1232        let t2 = test_table_scan_with_name("t2")?;
1233
1234        // could not eliminate to inner join since filter OR expression and there is no common
1235        // Join predicates in left and right of OR expr.
1236        let plan = LogicalPlanBuilder::from(t1)
1237            .cross_join(t2)?
1238            .filter(binary_expr(
1239                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1240                Or,
1241                col("t2.b").eq(col("t1.a")),
1242            ))?
1243            .build()?;
1244
1245        assert_optimized_plan_equal!(
1246            plan,
1247            @ r"
1248        Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1249          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1250            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1251            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1252        "
1253        )
1254    }
1255
1256    #[test]
1257    fn eliminate_cross_with_common_expr_and() -> Result<()> {
1258        let t1 = test_table_scan_with_name("t1")?;
1259        let t2 = test_table_scan_with_name("t2")?;
1260
1261        // could eliminate to inner join
1262        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1263        let plan = LogicalPlanBuilder::from(t1)
1264            .cross_join(t2)?
1265            .filter(binary_expr(
1266                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
1267                And,
1268                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
1269            ))?
1270            .build()?;
1271
1272        assert_optimized_plan_equal!(
1273            plan,
1274            @ r"
1275        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1276          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1277            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1278            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1279        "
1280        )
1281    }
1282
1283    #[test]
1284    fn eliminate_cross_with_common_expr_or() -> Result<()> {
1285        let t1 = test_table_scan_with_name("t1")?;
1286        let t2 = test_table_scan_with_name("t2")?;
1287
1288        // could eliminate to inner join since Or predicates have common Join predicates
1289        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1290        let plan = LogicalPlanBuilder::from(t1)
1291            .cross_join(t2)?
1292            .filter(binary_expr(
1293                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
1294                Or,
1295                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
1296            ))?
1297            .build()?;
1298
1299        assert_optimized_plan_equal!(
1300            plan,
1301            @ r"
1302        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1303          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1304            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1305            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1306        "
1307        )
1308    }
1309
1310    #[test]
1311    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
1312        let t1 = test_table_scan_with_name("t1")?;
1313        let t2 = test_table_scan_with_name("t2")?;
1314        let t3 = test_table_scan_with_name("t3")?;
1315
1316        // could eliminate to inner join
1317        let plan = LogicalPlanBuilder::from(t1)
1318            .cross_join(t2)?
1319            .cross_join(t3)?
1320            .filter(binary_expr(
1321                binary_expr(
1322                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
1323                    And,
1324                    col("t3.c").lt(lit(15u32)),
1325                ),
1326                And,
1327                binary_expr(
1328                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1329                    And,
1330                    col("t3.b").lt(lit(15u32)),
1331                ),
1332            ))?
1333            .build()?;
1334
1335        assert_optimized_plan_equal!(
1336            plan,
1337            @ r"
1338        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1339          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1340            Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1341              Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1342                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1343                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1344              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1345        "
1346        )
1347    }
1348
1349    #[test]
1350    fn preserve_null_equality_setting() -> Result<()> {
1351        let t1 = test_table_scan_with_name("t1")?;
1352        let t2 = test_table_scan_with_name("t2")?;
1353
1354        // Create an inner join with NullEquality::NullEqualsNull
1355        let join_schema = Arc::new(build_join_schema(
1356            t1.schema(),
1357            t2.schema(),
1358            &JoinType::Inner,
1359        )?);
1360
1361        let inner_join = LogicalPlan::Join(Join {
1362            left: Arc::new(t1),
1363            right: Arc::new(t2),
1364            join_type: JoinType::Inner,
1365            join_constraint: JoinConstraint::On,
1366            on: vec![],
1367            filter: None,
1368            schema: join_schema,
1369            null_equality: NullEquality::NullEqualsNull, // Test preservation
1370        });
1371
1372        // Apply filter that can create join conditions
1373        let plan = LogicalPlanBuilder::from(inner_join)
1374            .filter(binary_expr(
1375                col("t1.a").eq(col("t2.a")),
1376                And,
1377                col("t2.c").lt(lit(20u32)),
1378            ))?
1379            .build()?;
1380
1381        let rule = EliminateCrossJoin::new();
1382        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
1383
1384        // Verify that null_equality is preserved in the optimized plan
1385        fn check_null_equality_preserved(plan: &LogicalPlan) -> bool {
1386            match plan {
1387                LogicalPlan::Join(join) => {
1388                    // All joins in the optimized plan should preserve null equality
1389                    if join.null_equality == NullEquality::NullEqualsNothing {
1390                        return false;
1391                    }
1392                    // Recursively check child plans
1393                    plan.inputs()
1394                        .iter()
1395                        .all(|input| check_null_equality_preserved(input))
1396                }
1397                _ => {
1398                    // Recursively check child plans for non-join nodes
1399                    plan.inputs()
1400                        .iter()
1401                        .all(|input| check_null_equality_preserved(input))
1402                }
1403            }
1404        }
1405
1406        assert!(
1407            check_null_equality_preserved(&optimized_plan),
1408            "null_equality setting should be preserved after optimization"
1409        );
1410
1411        Ok(())
1412    }
1413}