datafusion_optimizer/
eliminate_cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27    Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator};
31
/// Optimizer rule that turns `CROSS JOIN`s into `INNER JOIN`s when equality
/// predicates between the joined inputs are available.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new [`EliminateCrossJoin`] rule.
    pub fn new() -> Self {
        Self
    }
}
41
42/// Eliminate cross joins by rewriting them to inner joins when possible.
43///
44/// # Example
45/// The initial plan for this query:
46/// ```sql
47/// select ... from a, b where a.x = b.y and b.xx = 100;
48/// ```
49///
50/// Looks like this:
51/// ```text
52/// Filter(a.x = b.y AND b.xx = 100)
53///  Cross Join
54///   TableScan a
55///   TableScan b
56/// ```
57///
58/// After the rule is applied, the plan will look like this:
59/// ```text
60/// Filter(b.xx = 100)
61///   InnerJoin(a.x = b.y)
62///     TableScan a
63///     TableScan b
64/// ```
65///
66/// # Other Examples
67/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
68/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
69/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
70/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);'
71/// * 'select ... from a, b where a.x > b.y'
72///
73/// For above queries, the join predicate is available in filters and they are moved to
74/// join nodes appropriately
75///
76/// This fix helps to improve the performance of TPCH Q19. issue#78
77impl OptimizerRule for EliminateCrossJoin {
78    fn supports_rewrite(&self) -> bool {
79        true
80    }
81
82    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        let plan_schema = Arc::clone(plan.schema());
89        let mut possible_join_keys = JoinKeySet::new();
90        let mut all_inputs: Vec<LogicalPlan> = vec![];
91        let mut all_filters: Vec<Expr> = vec![];
92        let mut null_equality = NullEquality::NullEqualsNothing;
93
94        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
95            // if input isn't a join that can potentially be rewritten
96            // avoid unwrapping the input
97            let rewritable = matches!(
98                filter.input.as_ref(),
99                LogicalPlan::Join(Join {
100                    join_type: JoinType::Inner,
101                    ..
102                })
103            );
104
105            if !rewritable {
106                // recursively try to rewrite children
107                return rewrite_children(self, LogicalPlan::Filter(filter), config);
108            }
109
110            if !can_flatten_join_inputs(&filter.input) {
111                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
112            }
113
114            let Filter {
115                input, predicate, ..
116            } = filter;
117
118            // Extract null_equality setting from the input join
119            if let LogicalPlan::Join(join) = input.as_ref() {
120                null_equality = join.null_equality;
121            }
122
123            flatten_join_inputs(
124                Arc::unwrap_or_clone(input),
125                &mut possible_join_keys,
126                &mut all_inputs,
127                &mut all_filters,
128            )?;
129
130            extract_possible_join_keys(&predicate, &mut possible_join_keys);
131            Some(predicate)
132        } else {
133            match plan {
134                LogicalPlan::Join(Join {
135                    join_type: JoinType::Inner,
136                    null_equality: original_null_equality,
137                    ..
138                }) => {
139                    if !can_flatten_join_inputs(&plan) {
140                        return Ok(Transformed::no(plan));
141                    }
142                    flatten_join_inputs(
143                        plan,
144                        &mut possible_join_keys,
145                        &mut all_inputs,
146                        &mut all_filters,
147                    )?;
148                    null_equality = original_null_equality;
149                    None
150                }
151                _ => {
152                    // recursively try to rewrite children
153                    return rewrite_children(self, plan, config);
154                }
155            }
156        };
157
158        // Join keys are handled locally:
159        let mut all_join_keys = JoinKeySet::new();
160        let mut left = all_inputs.remove(0);
161        while !all_inputs.is_empty() {
162            left = find_inner_join(
163                left,
164                &mut all_inputs,
165                &possible_join_keys,
166                &mut all_join_keys,
167                null_equality,
168            )?;
169        }
170
171        left = rewrite_children(self, left, config)?.data;
172
173        if &plan_schema != left.schema() {
174            left = LogicalPlan::Projection(Projection::new_from_schema(
175                Arc::new(left),
176                Arc::clone(&plan_schema),
177            ));
178        }
179
180        if !all_filters.is_empty() {
181            // Add any filters on top - PushDownFilter can push filters down to applicable join
182            let first = all_filters.swap_remove(0);
183            let predicate = all_filters.into_iter().fold(first, and);
184            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
185        }
186
187        let Some(predicate) = parent_predicate else {
188            return Ok(Transformed::yes(left));
189        };
190
191        // If there are no join keys then do nothing:
192        if all_join_keys.is_empty() {
193            Filter::try_new(predicate, Arc::new(left))
194                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
195        } else {
196            // Remove join expressions from filter:
197            match remove_join_expressions(predicate, &all_join_keys) {
198                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
199                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
200                _ => Ok(Transformed::yes(left)),
201            }
202        }
203    }
204
205    fn name(&self) -> &str {
206        "eliminate_cross_join"
207    }
208}
209
210fn rewrite_children(
211    optimizer: &impl OptimizerRule,
212    plan: LogicalPlan,
213    config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215    let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
216
217    // recompute schema if the plan was transformed
218    if transformed_plan.transformed {
219        transformed_plan.map_data(|plan| plan.recompute_schema())
220    } else {
221        Ok(transformed_plan)
222    }
223}
224
225/// Recursively accumulate possible_join_keys and inputs from inner joins
226/// (including cross joins).
227///
228/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
229/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
230/// possible_join_keys
231fn flatten_join_inputs(
232    plan: LogicalPlan,
233    possible_join_keys: &mut JoinKeySet,
234    all_inputs: &mut Vec<LogicalPlan>,
235    all_filters: &mut Vec<Expr>,
236) -> Result<()> {
237    match plan {
238        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
239            if let Some(filter) = join.filter {
240                all_filters.push(filter);
241            }
242            possible_join_keys.insert_all_owned(join.on);
243            flatten_join_inputs(
244                Arc::unwrap_or_clone(join.left),
245                possible_join_keys,
246                all_inputs,
247                all_filters,
248            )?;
249            flatten_join_inputs(
250                Arc::unwrap_or_clone(join.right),
251                possible_join_keys,
252                all_inputs,
253                all_filters,
254            )?;
255        }
256        _ => {
257            all_inputs.push(plan);
258        }
259    };
260    Ok(())
261}
262
263/// Returns true if the plan is a Join or Cross join could be flattened with
264/// `flatten_join_inputs`
265///
266/// Must stay in sync with `flatten_join_inputs`
267fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
268    // can only flatten inner / cross joins
269    match plan {
270        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
271        _ => return false,
272    };
273
274    for child in plan.inputs() {
275        if let LogicalPlan::Join(Join {
276            join_type: JoinType::Inner,
277            ..
278        }) = child
279        {
280            if !can_flatten_join_inputs(child) {
281                return false;
282            }
283        }
284    }
285    true
286}
287
288/// Finds the next to join with the left input plan,
289///
290/// Finds the next `right` from `rights` that can be joined with `left_input`
291/// plan based on the join keys in `possible_join_keys`.
292///
293/// If such a matching `right` is found:
294/// 1. Adds the matching join keys to `all_join_keys`.
295/// 2. Returns `left_input JOIN right ON (all join keys)`.
296///
297/// If no matching `right` is found:
298/// 1. Removes the first plan from `rights`
299/// 2. Returns `left_input CROSS JOIN right`.
300fn find_inner_join(
301    left_input: LogicalPlan,
302    rights: &mut Vec<LogicalPlan>,
303    possible_join_keys: &JoinKeySet,
304    all_join_keys: &mut JoinKeySet,
305    null_equality: NullEquality,
306) -> Result<LogicalPlan> {
307    for (i, right_input) in rights.iter().enumerate() {
308        let mut join_keys = vec![];
309
310        for (l, r) in possible_join_keys.iter() {
311            let key_pair = find_valid_equijoin_key_pair(
312                l,
313                r,
314                left_input.schema(),
315                right_input.schema(),
316            )?;
317
318            // Save join keys
319            if let Some((valid_l, valid_r)) = key_pair {
320                if can_hash(&valid_l.get_type(left_input.schema())?) {
321                    join_keys.push((valid_l, valid_r));
322                }
323            }
324        }
325
326        // Found one or more matching join keys
327        if !join_keys.is_empty() {
328            all_join_keys.insert_all(join_keys.iter());
329            let right_input = rights.remove(i);
330            let join_schema = Arc::new(build_join_schema(
331                left_input.schema(),
332                right_input.schema(),
333                &JoinType::Inner,
334            )?);
335
336            return Ok(LogicalPlan::Join(Join {
337                left: Arc::new(left_input),
338                right: Arc::new(right_input),
339                join_type: JoinType::Inner,
340                join_constraint: JoinConstraint::On,
341                on: join_keys,
342                filter: None,
343                schema: join_schema,
344                null_equality,
345            }));
346        }
347    }
348
349    // no matching right plan had any join keys, cross join with the first right
350    // plan
351    let right = rights.remove(0);
352    let join_schema = Arc::new(build_join_schema(
353        left_input.schema(),
354        right.schema(),
355        &JoinType::Inner,
356    )?);
357
358    Ok(LogicalPlan::Join(Join {
359        left: Arc::new(left_input),
360        right: Arc::new(right),
361        schema: join_schema,
362        on: vec![],
363        filter: None,
364        join_type: JoinType::Inner,
365        join_constraint: JoinConstraint::On,
366        null_equality,
367    }))
368}
369
370/// Extract join keys from a WHERE clause
371fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
372    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
373        match op {
374            Operator::Eq => {
375                // insert handles ensuring  we don't add the same Join keys multiple times
376                join_keys.insert(left, right);
377            }
378            Operator::And => {
379                extract_possible_join_keys(left, join_keys);
380                extract_possible_join_keys(right, join_keys)
381            }
382            // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
383            Operator::Or => {
384                let mut left_join_keys = JoinKeySet::new();
385                let mut right_join_keys = JoinKeySet::new();
386
387                extract_possible_join_keys(left, &mut left_join_keys);
388                extract_possible_join_keys(right, &mut right_join_keys);
389
390                join_keys.insert_intersection(&left_join_keys, &right_join_keys)
391            }
392            _ => (),
393        };
394    }
395}
396
397/// Remove join expressions from a filter expression
398///
399/// # Returns
400/// * `Some()` when there are few remaining predicates in filter_expr
401/// * `None` otherwise
402fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
403    match expr {
404        Expr::BinaryExpr(BinaryExpr {
405            left,
406            op: Operator::Eq,
407            right,
408        }) if join_keys.contains(&left, &right) => {
409            // was a join key, so remove it
410            None
411        }
412        // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
413        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
414            let l = remove_join_expressions(*left, join_keys);
415            let r = remove_join_expressions(*right, join_keys);
416            match (l, r) {
417                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
418                    Box::new(ll),
419                    op,
420                    Box::new(rr),
421                ))),
422                (Some(ll), _) => Some(ll),
423                (_, Some(rr)) => Some(rr),
424                _ => None,
425            }
426        }
427        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
428            let l = remove_join_expressions(*left, join_keys);
429            let r = remove_join_expressions(*right, join_keys);
430            match (l, r) {
431                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
432                    Box::new(ll),
433                    op,
434                    Box::new(rr),
435                ))),
436                // When either `left` or `right` is empty, it means they are `true`
437                // so OR'ing anything with them will also be true
438                _ => None,
439            }
440        }
441        _ => Some(expr),
442    }
443}
444
445#[cfg(test)]
446mod tests {
447    use super::*;
448    use crate::optimizer::OptimizerContext;
449    use crate::test::*;
450
451    use datafusion_expr::{
452        binary_expr, col, lit,
453        logical_plan::builder::LogicalPlanBuilder,
454        Operator::{And, Or},
455    };
456    use insta::assert_snapshot;
457
458    macro_rules! assert_optimized_plan_equal {
459        (
460            $plan:expr,
461            @ $expected:literal $(,)?
462        ) => {{
463            let starting_schema = Arc::clone($plan.schema());
464            let rule = EliminateCrossJoin::new();
465            let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
466            let formatted_plan = optimized_plan.display_indent_schema();
467            // Ensure the rule was actually applied
468            assert!(is_plan_transformed, "failed to optimize plan");
469            // Verify the schema remains unchanged
470            assert_eq!(&starting_schema, optimized_plan.schema());
471            assert_snapshot!(
472                formatted_plan,
473                @ $expected,
474            );
475
476            Ok(())
477        }};
478    }
479
480    #[test]
481    fn eliminate_cross_with_simple_and() -> Result<()> {
482        let t1 = test_table_scan_with_name("t1")?;
483        let t2 = test_table_scan_with_name("t2")?;
484
485        // could eliminate to inner join since filter has Join predicates
486        let plan = LogicalPlanBuilder::from(t1)
487            .cross_join(t2)?
488            .filter(binary_expr(
489                col("t1.a").eq(col("t2.a")),
490                And,
491                col("t2.c").lt(lit(20u32)),
492            ))?
493            .build()?;
494
495        assert_optimized_plan_equal!(
496            plan,
497            @ r"
498        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
499          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
500            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
501            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
502        "
503        )
504    }
505
506    #[test]
507    fn eliminate_cross_with_simple_or() -> Result<()> {
508        let t1 = test_table_scan_with_name("t1")?;
509        let t2 = test_table_scan_with_name("t2")?;
510
511        // could not eliminate to inner join since filter OR expression and there is no common
512        // Join predicates in left and right of OR expr.
513        let plan = LogicalPlanBuilder::from(t1)
514            .cross_join(t2)?
515            .filter(binary_expr(
516                col("t1.a").eq(col("t2.a")),
517                Or,
518                col("t2.b").eq(col("t1.a")),
519            ))?
520            .build()?;
521
522        assert_optimized_plan_equal!(
523            plan,
524            @ r"
525        Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
526          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
527            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
528            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
529        "
530        )
531    }
532
533    #[test]
534    fn eliminate_cross_with_and() -> Result<()> {
535        let t1 = test_table_scan_with_name("t1")?;
536        let t2 = test_table_scan_with_name("t2")?;
537
538        // could eliminate to inner join
539        let plan = LogicalPlanBuilder::from(t1)
540            .cross_join(t2)?
541            .filter(binary_expr(
542                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
543                And,
544                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
545            ))?
546            .build()?;
547
548        assert_optimized_plan_equal!(
549            plan,
550            @ r"
551        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
552          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
553            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
554            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
555        "
556        )
557    }
558
559    #[test]
560    fn eliminate_cross_with_or() -> Result<()> {
561        let t1 = test_table_scan_with_name("t1")?;
562        let t2 = test_table_scan_with_name("t2")?;
563
564        // could eliminate to inner join since Or predicates have common Join predicates
565        let plan = LogicalPlanBuilder::from(t1)
566            .cross_join(t2)?
567            .filter(binary_expr(
568                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
569                Or,
570                binary_expr(
571                    col("t1.a").eq(col("t2.a")),
572                    And,
573                    col("t2.c").eq(lit(688u32)),
574                ),
575            ))?
576            .build()?;
577
578        assert_optimized_plan_equal!(
579            plan,
580            @ r"
581        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
582          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
583            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
584            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
585        "
586        )
587    }
588
589    #[test]
590    fn eliminate_cross_not_possible_simple() -> Result<()> {
591        let t1 = test_table_scan_with_name("t1")?;
592        let t2 = test_table_scan_with_name("t2")?;
593
594        // could not eliminate to inner join
595        let plan = LogicalPlanBuilder::from(t1)
596            .cross_join(t2)?
597            .filter(binary_expr(
598                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
599                Or,
600                binary_expr(
601                    col("t1.b").eq(col("t2.b")),
602                    And,
603                    col("t2.c").eq(lit(688u32)),
604                ),
605            ))?
606            .build()?;
607
608        assert_optimized_plan_equal!(
609            plan,
610            @ r"
611        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
612          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
613            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
614            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
615        "
616        )
617    }
618
619    #[test]
620    fn eliminate_cross_not_possible() -> Result<()> {
621        let t1 = test_table_scan_with_name("t1")?;
622        let t2 = test_table_scan_with_name("t2")?;
623
624        // could not eliminate to inner join
625        let plan = LogicalPlanBuilder::from(t1)
626            .cross_join(t2)?
627            .filter(binary_expr(
628                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
629                Or,
630                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
631            ))?
632            .build()?;
633
634        assert_optimized_plan_equal!(
635            plan,
636            @ r"
637        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
638          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
639            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
640            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
641        "
642        )
643    }
644
645    #[test]
646    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
647        let t1 = test_table_scan_with_name("t1")?;
648        let t2 = test_table_scan_with_name("t2")?;
649        let t3 = test_table_scan_with_name("t3")?;
650
651        // could not eliminate to inner join with filter
652        let plan = LogicalPlanBuilder::from(t1)
653            .join(
654                t3,
655                JoinType::Inner,
656                (vec!["t1.a"], vec!["t3.a"]),
657                Some(col("t1.a").gt(lit(20u32))),
658            )?
659            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
660            .filter(col("t1.a").gt(lit(15u32)))?
661            .build()?;
662
663        assert_optimized_plan_equal!(
664            plan,
665            @ r"
666        Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
667          Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
668            Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
669              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
670                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
671                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
672              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
673        "
674        )
675    }
676
677    #[test]
678    /// ```txt
679    /// filter: a.id = b.id and a.id = c.id
680    ///   cross_join a (bc)
681    ///     cross_join b c
682    /// ```
683    /// Without reorder, it will be
684    /// ```txt
685    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
686    ///     cross_join b c
687    /// ```
688    /// Reorder it to be
689    /// ```txt
690    ///   inner_join (ab)c and a.id = c.id
691    ///     inner_join a b on a.id = b.id
692    /// ```
693    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
694        let t1 = test_table_scan_with_name("t1")?;
695        let t2 = test_table_scan_with_name("t2")?;
696        let t3 = test_table_scan_with_name("t3")?;
697
698        // could eliminate to inner join
699        let plan = LogicalPlanBuilder::from(t1)
700            .cross_join(t2)?
701            .cross_join(t3)?
702            .filter(binary_expr(
703                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
704                And,
705                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
706            ))?
707            .build()?;
708
709        assert_optimized_plan_equal!(
710            plan,
711            @ r"
712        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
713          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
714            Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
715              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
716                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
717                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
718              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
719        "
720        )
721    }
722
723    #[test]
724    fn eliminate_cross_join_multi_tables() -> Result<()> {
725        let t1 = test_table_scan_with_name("t1")?;
726        let t2 = test_table_scan_with_name("t2")?;
727        let t3 = test_table_scan_with_name("t3")?;
728        let t4 = test_table_scan_with_name("t4")?;
729
730        // could eliminate to inner join
731        let plan1 = LogicalPlanBuilder::from(t1)
732            .cross_join(t2)?
733            .filter(binary_expr(
734                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
735                Or,
736                binary_expr(
737                    col("t1.a").eq(col("t2.a")),
738                    And,
739                    col("t2.c").eq(lit(688u32)),
740                ),
741            ))?
742            .build()?;
743
744        let plan2 = LogicalPlanBuilder::from(t3)
745            .cross_join(t4)?
746            .filter(binary_expr(
747                binary_expr(
748                    binary_expr(
749                        col("t3.a").eq(col("t4.a")),
750                        And,
751                        col("t4.c").lt(lit(15u32)),
752                    ),
753                    Or,
754                    binary_expr(
755                        col("t3.a").eq(col("t4.a")),
756                        And,
757                        col("t3.c").eq(lit(688u32)),
758                    ),
759                ),
760                Or,
761                binary_expr(
762                    col("t3.a").eq(col("t4.a")),
763                    And,
764                    col("t3.b").eq(col("t4.b")),
765                ),
766            ))?
767            .build()?;
768
769        let plan = LogicalPlanBuilder::from(plan1)
770            .cross_join(plan2)?
771            .filter(binary_expr(
772                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
773                Or,
774                binary_expr(
775                    col("t3.a").eq(col("t1.a")),
776                    And,
777                    col("t4.c").eq(lit(688u32)),
778                ),
779            ))?
780            .build()?;
781
782        assert_optimized_plan_equal!(
783            plan,
784            @ r"
785        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
786          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
787            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
788              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
789                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
790                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
791            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
792              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
793                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
794                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
795        "
796        )
797    }
798
    /// Nested cross joins where both inner cross joins can be rewritten but the
    /// outer one cannot.
    ///
    /// Every `OR` branch of each inner filter has the form `key AND ...`, so the
    /// shared `key` (`t1.a = t2.a` / `t3.a = t4.a`) becomes the inner join
    /// condition. The outer filter is `(t3.a = t1.a AND ...) OR (t3.a = t1.a OR ...)`:
    /// in the second branch the equality is only an `OR` operand, so it is not a
    /// key required by every branch and the outer cross join is kept.
    #[test]
    fn eliminate_cross_join_multi_tables_1() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join: `t1.a = t2.a` is ANDed into both OR branches
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could eliminate to inner join: `t3.a = t4.a` is ANDed into all three OR branches
        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t4.a")),
                    And,
                    col("t3.b").eq(col("t4.b")),
                ),
            ))?
            .build()?;

        // could not eliminate to inner join: the second OR branch uses
        // `t3.a = t1.a OR ...`, so the equality is not required by every branch
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
            ))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @ r"
        Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
872
    /// Nested cross joins where the `t3 x t4` cross join cannot be rewritten but
    /// the outer one and `t1 x t2` can.
    ///
    /// The last `OR` branch of the `t3 x t4` filter is `t3.a = t4.a OR t3.b = t4.b`,
    /// so `t3.a = t4.a` is not required by every branch and that cross join stays.
    /// The outer filter ANDs `t3.a = t1.a` into both of its `OR` branches, so the
    /// outer cross join becomes `Inner Join: t1.a = t3.a`.
    #[test]
    fn eliminate_cross_join_multi_tables_2() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join: `t1.a = t2.a` is ANDed into both OR branches
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could not eliminate to inner join: the last OR branch only ORs
        // `t3.a = t4.a` with another predicate
        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
            ))?
            .build()?;

        // could eliminate to inner join: `t3.a = t1.a` is ANDed into both OR branches
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @ r"
        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
946
    /// Nested cross joins where the `t1 x t2` cross join cannot be rewritten but
    /// the outer one and `t3 x t4` can.
    ///
    /// The first `OR` branch of the `t1 x t2` filter is `t1.a = t2.a OR t2.c < 15`,
    /// so `t1.a = t2.a` is not required by every branch and that cross join stays
    /// (its filter is left unchanged). The outer filter ANDs `t3.a = t1.a` into
    /// both of its `OR` branches, so the outer cross join becomes an inner join.
    #[test]
    fn eliminate_cross_join_multi_tables_3() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could not eliminate to inner join: the first OR branch only ORs
        // `t1.a = t2.a` with another predicate
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could eliminate to inner join: `t3.a = t4.a` is ANDed into all three OR branches
        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t4.a")),
                    And,
                    col("t3.b").eq(col("t4.b")),
                ),
            ))?
            .build()?;

        // could eliminate to inner join: `t3.a = t1.a` is ANDed into both OR branches
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @ r"
        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
1024
    /// Nested cross joins where the outer filter is a conjunction that supplies
    /// join keys for both the outer cross join and the unfiltered `t3 x t4`.
    ///
    /// Because every cross join ends up with a usable key, the inputs are
    /// flattened into a single chain of inner joins (see the expected plan),
    /// with the non-key predicates kept in filters above.
    #[test]
    fn eliminate_cross_join_multi_tables_4() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join: the top-level AND contains `t1.a = t2.a`
        // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688)
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
                And,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // no filter of its own; it becomes an inner join only because the outer
        // filter below implies `t3.a = t4.a`
        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

        // could eliminate to inner join
        // filter:
        //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
        //     AND
        //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t1.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t1.a")),
                        And,
                        col("t4.c").eq(lit(688u32)),
                    ),
                ),
                And,
                binary_expr(
                    binary_expr(
                        binary_expr(
                            col("t3.a").eq(col("t4.a")),
                            And,
                            col("t4.c").lt(lit(15u32)),
                        ),
                        Or,
                        binary_expr(
                            col("t3.a").eq(col("t4.a")),
                            And,
                            col("t3.c").eq(lit(688u32)),
                        ),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.b").eq(col("t4.b")),
                    ),
                ),
            ))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @ r"
        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
                  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
1111
    /// Nested cross joins where neither inner cross join has a filter; every join
    /// key (`t1.a = t2.a`, `t1.a = t3.a`, `t3.a = t4.a`) comes from the single
    /// conjunctive filter on top.
    ///
    /// All four scans are flattened into one chain of inner joins, and the
    /// remaining non-key predicates stay in the top-level filter.
    #[test]
    fn eliminate_cross_join_multi_tables_5() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join (no filter here; the key comes from the top-level filter)
        let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;

        // could eliminate to inner join (no filter here; the key comes from the top-level filter)
        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

        // could eliminate to inner join
        // Filter:
        //  ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
        //      AND
        //  ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
        //      AND
        //  ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        binary_expr(
                            col("t3.a").eq(col("t1.a")),
                            And,
                            col("t4.c").lt(lit(15u32)),
                        ),
                        Or,
                        binary_expr(
                            col("t3.a").eq(col("t1.a")),
                            And,
                            col("t4.c").eq(lit(688u32)),
                        ),
                    ),
                    And,
                    binary_expr(
                        binary_expr(
                            binary_expr(
                                col("t3.a").eq(col("t4.a")),
                                And,
                                col("t4.c").lt(lit(15u32)),
                            ),
                            Or,
                            binary_expr(
                                col("t3.a").eq(col("t4.a")),
                                And,
                                col("t3.c").eq(lit(688u32)),
                            ),
                        ),
                        Or,
                        binary_expr(
                            col("t3.a").eq(col("t4.a")),
                            And,
                            col("t3.b").eq(col("t4.b")),
                        ),
                    ),
                ),
                And,
                binary_expr(
                    binary_expr(
                        col("t1.a").eq(col("t2.a")),
                        Or,
                        col("t2.c").lt(lit(15u32)),
                    ),
                    And,
                    binary_expr(
                        col("t1.a").eq(col("t2.a")),
                        And,
                        col("t2.c").eq(lit(688u32)),
                    ),
                ),
            ))?
            .build()?;

        assert_optimized_plan_equal!(
            plan,
            @ r"
        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
1203
1204    #[test]
1205    fn eliminate_cross_join_with_expr_and() -> Result<()> {
1206        let t1 = test_table_scan_with_name("t1")?;
1207        let t2 = test_table_scan_with_name("t2")?;
1208
1209        // could eliminate to inner join since filter has Join predicates
1210        let plan = LogicalPlanBuilder::from(t1)
1211            .cross_join(t2)?
1212            .filter(binary_expr(
1213                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1214                And,
1215                col("t2.c").lt(lit(20u32)),
1216            ))?
1217            .build()?;
1218
1219        assert_optimized_plan_equal!(
1220            plan,
1221            @ r"
1222        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1223          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1224            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1225            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1226        "
1227        )
1228    }
1229
1230    #[test]
1231    fn eliminate_cross_with_expr_or() -> Result<()> {
1232        let t1 = test_table_scan_with_name("t1")?;
1233        let t2 = test_table_scan_with_name("t2")?;
1234
1235        // could not eliminate to inner join since filter OR expression and there is no common
1236        // Join predicates in left and right of OR expr.
1237        let plan = LogicalPlanBuilder::from(t1)
1238            .cross_join(t2)?
1239            .filter(binary_expr(
1240                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1241                Or,
1242                col("t2.b").eq(col("t1.a")),
1243            ))?
1244            .build()?;
1245
1246        assert_optimized_plan_equal!(
1247            plan,
1248            @ r"
1249        Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1250          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1251            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1252            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1253        "
1254        )
1255    }
1256
1257    #[test]
1258    fn eliminate_cross_with_common_expr_and() -> Result<()> {
1259        let t1 = test_table_scan_with_name("t1")?;
1260        let t2 = test_table_scan_with_name("t2")?;
1261
1262        // could eliminate to inner join
1263        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1264        let plan = LogicalPlanBuilder::from(t1)
1265            .cross_join(t2)?
1266            .filter(binary_expr(
1267                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
1268                And,
1269                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
1270            ))?
1271            .build()?;
1272
1273        assert_optimized_plan_equal!(
1274            plan,
1275            @ r"
1276        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1277          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1278            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1279            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1280        "
1281        )
1282    }
1283
1284    #[test]
1285    fn eliminate_cross_with_common_expr_or() -> Result<()> {
1286        let t1 = test_table_scan_with_name("t1")?;
1287        let t2 = test_table_scan_with_name("t2")?;
1288
1289        // could eliminate to inner join since Or predicates have common Join predicates
1290        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1291        let plan = LogicalPlanBuilder::from(t1)
1292            .cross_join(t2)?
1293            .filter(binary_expr(
1294                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
1295                Or,
1296                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
1297            ))?
1298            .build()?;
1299
1300        assert_optimized_plan_equal!(
1301            plan,
1302            @ r"
1303        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1304          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1305            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1306            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1307        "
1308        )
1309    }
1310
1311    #[test]
1312    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
1313        let t1 = test_table_scan_with_name("t1")?;
1314        let t2 = test_table_scan_with_name("t2")?;
1315        let t3 = test_table_scan_with_name("t3")?;
1316
1317        // could eliminate to inner join
1318        let plan = LogicalPlanBuilder::from(t1)
1319            .cross_join(t2)?
1320            .cross_join(t3)?
1321            .filter(binary_expr(
1322                binary_expr(
1323                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
1324                    And,
1325                    col("t3.c").lt(lit(15u32)),
1326                ),
1327                And,
1328                binary_expr(
1329                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1330                    And,
1331                    col("t3.b").lt(lit(15u32)),
1332                ),
1333            ))?
1334            .build()?;
1335
1336        assert_optimized_plan_equal!(
1337            plan,
1338            @ r"
1339        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1340          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1341            Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1342              Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1343                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1344                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1345              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1346        "
1347        )
1348    }
1349
1350    #[test]
1351    fn preserve_null_equality_setting() -> Result<()> {
1352        let t1 = test_table_scan_with_name("t1")?;
1353        let t2 = test_table_scan_with_name("t2")?;
1354
1355        // Create an inner join with NullEquality::NullEqualsNull
1356        let join_schema = Arc::new(build_join_schema(
1357            t1.schema(),
1358            t2.schema(),
1359            &JoinType::Inner,
1360        )?);
1361
1362        let inner_join = LogicalPlan::Join(Join {
1363            left: Arc::new(t1),
1364            right: Arc::new(t2),
1365            join_type: JoinType::Inner,
1366            join_constraint: JoinConstraint::On,
1367            on: vec![],
1368            filter: None,
1369            schema: join_schema,
1370            null_equality: NullEquality::NullEqualsNull, // Test preservation
1371        });
1372
1373        // Apply filter that can create join conditions
1374        let plan = LogicalPlanBuilder::from(inner_join)
1375            .filter(binary_expr(
1376                col("t1.a").eq(col("t2.a")),
1377                And,
1378                col("t2.c").lt(lit(20u32)),
1379            ))?
1380            .build()?;
1381
1382        let rule = EliminateCrossJoin::new();
1383        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
1384
1385        // Verify that null_equality is preserved in the optimized plan
1386        fn check_null_equality_preserved(plan: &LogicalPlan) -> bool {
1387            match plan {
1388                LogicalPlan::Join(join) => {
1389                    // All joins in the optimized plan should preserve null equality
1390                    if join.null_equality == NullEquality::NullEqualsNothing {
1391                        return false;
1392                    }
1393                    // Recursively check child plans
1394                    plan.inputs()
1395                        .iter()
1396                        .all(|input| check_null_equality_preserved(input))
1397                }
1398                _ => {
1399                    // Recursively check child plans for non-join nodes
1400                    plan.inputs()
1401                        .iter()
1402                        .all(|input| check_null_equality_preserved(input))
1403                }
1404            }
1405        }
1406
1407        assert!(
1408            check_null_equality_preserved(&optimized_plan),
1409            "null_equality setting should be preserved after optimization"
1410        );
1411
1412        Ok(())
1413    }
1414}