1use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27 Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema};
31
/// Optimizer rule that rewrites cross joins (inner joins without join keys)
/// into keyed inner joins when equality predicates between their inputs are
/// available, reordering the joined inputs where that makes more keys usable.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new [`EliminateCrossJoin`] rule.
    pub fn new() -> Self {
        Self
    }
}
41
42impl OptimizerRule for EliminateCrossJoin {
78 fn supports_rewrite(&self) -> bool {
79 true
80 }
81
82 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83 fn rewrite(
84 &self,
85 plan: LogicalPlan,
86 config: &dyn OptimizerConfig,
87 ) -> Result<Transformed<LogicalPlan>> {
88 let plan_schema = Arc::clone(plan.schema());
89 let mut possible_join_keys = JoinKeySet::new();
90 let mut all_inputs: Vec<LogicalPlan> = vec![];
91 let mut all_filters: Vec<Expr> = vec![];
92 let mut null_equality = NullEquality::NullEqualsNothing;
93
94 let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
95 let rewritable = matches!(
98 filter.input.as_ref(),
99 LogicalPlan::Join(Join {
100 join_type: JoinType::Inner,
101 ..
102 })
103 );
104
105 if !rewritable {
106 return rewrite_children(self, LogicalPlan::Filter(filter), config);
108 }
109
110 if !can_flatten_join_inputs(&filter.input) {
111 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
112 }
113
114 let Filter {
115 input, predicate, ..
116 } = filter;
117
118 if let LogicalPlan::Join(join) = input.as_ref() {
120 null_equality = join.null_equality;
121 }
122
123 flatten_join_inputs(
124 Arc::unwrap_or_clone(input),
125 &mut possible_join_keys,
126 &mut all_inputs,
127 &mut all_filters,
128 )?;
129
130 extract_possible_join_keys(&predicate, &mut possible_join_keys);
131 Some(predicate)
132 } else {
133 match plan {
134 LogicalPlan::Join(Join {
135 join_type: JoinType::Inner,
136 null_equality: original_null_equality,
137 ..
138 }) => {
139 if !can_flatten_join_inputs(&plan) {
140 return Ok(Transformed::no(plan));
141 }
142 flatten_join_inputs(
143 plan,
144 &mut possible_join_keys,
145 &mut all_inputs,
146 &mut all_filters,
147 )?;
148 null_equality = original_null_equality;
149 None
150 }
151 _ => {
152 return rewrite_children(self, plan, config);
154 }
155 }
156 };
157
158 let mut all_join_keys = JoinKeySet::new();
160 let mut left = all_inputs.remove(0);
161 while !all_inputs.is_empty() {
162 left = find_inner_join(
163 left,
164 &mut all_inputs,
165 &possible_join_keys,
166 &mut all_join_keys,
167 null_equality,
168 )?;
169 }
170
171 left = rewrite_children(self, left, config)?.data;
172
173 if &plan_schema != left.schema() {
174 left = LogicalPlan::Projection(Projection::new_from_schema(
175 Arc::new(left),
176 Arc::clone(&plan_schema),
177 ));
178 }
179
180 if !all_filters.is_empty() {
181 let first = all_filters.swap_remove(0);
183 let predicate = all_filters.into_iter().fold(first, and);
184 left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
185 }
186
187 let Some(predicate) = parent_predicate else {
188 return Ok(Transformed::yes(left));
189 };
190
191 if all_join_keys.is_empty() {
193 Filter::try_new(predicate, Arc::new(left))
194 .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
195 } else {
196 match remove_join_expressions(predicate, &all_join_keys) {
198 Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
199 .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
200 _ => Ok(Transformed::yes(left)),
201 }
202 }
203 }
204
205 fn name(&self) -> &str {
206 "eliminate_cross_join"
207 }
208}
209
210fn rewrite_children(
211 optimizer: &impl OptimizerRule,
212 plan: LogicalPlan,
213 config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215 let transformed_plan = plan
217 .map_uncorrelated_subqueries(|input| optimizer.rewrite(input, config))?
218 .transform_sibling(|plan| {
219 plan.map_children(|input| optimizer.rewrite(input, config))
220 })?;
221
222 if transformed_plan.transformed {
224 transformed_plan.map_data(|plan| plan.recompute_schema())
225 } else {
226 Ok(transformed_plan)
227 }
228}
229
230fn flatten_join_inputs(
237 plan: LogicalPlan,
238 possible_join_keys: &mut JoinKeySet,
239 all_inputs: &mut Vec<LogicalPlan>,
240 all_filters: &mut Vec<Expr>,
241) -> Result<()> {
242 match plan {
243 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
244 if let Some(filter) = join.filter {
245 all_filters.push(filter);
246 }
247 possible_join_keys.insert_all_owned(join.on);
248 flatten_join_inputs(
249 Arc::unwrap_or_clone(join.left),
250 possible_join_keys,
251 all_inputs,
252 all_filters,
253 )?;
254 flatten_join_inputs(
255 Arc::unwrap_or_clone(join.right),
256 possible_join_keys,
257 all_inputs,
258 all_filters,
259 )?;
260 }
261 _ => {
262 all_inputs.push(plan);
263 }
264 };
265 Ok(())
266}
267
268fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
273 match plan {
275 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
276 _ => return false,
277 };
278
279 for child in plan.inputs() {
280 if let LogicalPlan::Join(Join {
281 join_type: JoinType::Inner,
282 ..
283 }) = child
284 && !can_flatten_join_inputs(child)
285 {
286 return false;
287 }
288 }
289 true
290}
291
292fn find_inner_join(
305 left_input: LogicalPlan,
306 rights: &mut Vec<LogicalPlan>,
307 possible_join_keys: &JoinKeySet,
308 all_join_keys: &mut JoinKeySet,
309 null_equality: NullEquality,
310) -> Result<LogicalPlan> {
311 for (i, right_input) in rights.iter().enumerate() {
312 let mut join_keys = vec![];
313
314 for (l, r) in possible_join_keys.iter() {
315 let key_pair = find_valid_equijoin_key_pair(
316 l,
317 r,
318 left_input.schema(),
319 right_input.schema(),
320 )?;
321
322 if let Some((valid_l, valid_r)) = key_pair
324 && can_hash(&valid_l.get_type(left_input.schema())?)
325 {
326 join_keys.push((valid_l, valid_r));
327 }
328 }
329
330 if !join_keys.is_empty() {
332 all_join_keys.insert_all(join_keys.iter());
333 let right_input = rights.remove(i);
334 let join_schema = Arc::new(build_join_schema(
335 left_input.schema(),
336 right_input.schema(),
337 &JoinType::Inner,
338 )?);
339
340 return Ok(LogicalPlan::Join(Join {
341 left: Arc::new(left_input),
342 right: Arc::new(right_input),
343 join_type: JoinType::Inner,
344 join_constraint: JoinConstraint::On,
345 on: join_keys,
346 filter: None,
347 schema: join_schema,
348 null_equality,
349 null_aware: false,
350 }));
351 }
352 }
353
354 let right = rights.remove(0);
357 let join_schema = Arc::new(build_join_schema(
358 left_input.schema(),
359 right.schema(),
360 &JoinType::Inner,
361 )?);
362
363 Ok(LogicalPlan::Join(Join {
364 left: Arc::new(left_input),
365 right: Arc::new(right),
366 schema: join_schema,
367 on: vec![],
368 filter: None,
369 join_type: JoinType::Inner,
370 join_constraint: JoinConstraint::On,
371 null_equality,
372 null_aware: false,
373 }))
374}
375
376fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
378 if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
379 match op {
380 Operator::Eq => {
381 join_keys.insert(left, right);
383 }
384 Operator::And => {
385 extract_possible_join_keys(left, join_keys);
386 extract_possible_join_keys(right, join_keys)
387 }
388 Operator::Or => {
390 let mut left_join_keys = JoinKeySet::new();
391 let mut right_join_keys = JoinKeySet::new();
392
393 extract_possible_join_keys(left, &mut left_join_keys);
394 extract_possible_join_keys(right, &mut right_join_keys);
395
396 join_keys.insert_intersection(&left_join_keys, &right_join_keys)
397 }
398 _ => (),
399 };
400 }
401}
402
403fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
409 match expr {
410 Expr::BinaryExpr(BinaryExpr {
411 left,
412 op: Operator::Eq,
413 right,
414 }) if join_keys.contains(&left, &right) => {
415 None
417 }
418 Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
420 let l = remove_join_expressions(*left, join_keys);
421 let r = remove_join_expressions(*right, join_keys);
422 match (l, r) {
423 (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
424 Box::new(ll),
425 op,
426 Box::new(rr),
427 ))),
428 (Some(ll), _) => Some(ll),
429 (_, Some(rr)) => Some(rr),
430 _ => None,
431 }
432 }
433 Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
434 let l = remove_join_expressions(*left, join_keys);
435 let r = remove_join_expressions(*right, join_keys);
436 match (l, r) {
437 (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
438 Box::new(ll),
439 op,
440 Box::new(rr),
441 ))),
442 _ => None,
445 }
446 }
447 _ => Some(expr),
448 }
449}
450
451#[cfg(test)]
452mod tests {
453 use super::*;
454 use crate::optimizer::OptimizerContext;
455 use crate::test::*;
456
457 use datafusion_expr::{
458 Operator::{And, Or},
459 binary_expr, col, lit,
460 logical_plan::builder::LogicalPlanBuilder,
461 };
462 use insta::assert_snapshot;
463
464 macro_rules! assert_optimized_plan_equal {
465 (
466 $plan:expr,
467 @ $expected:literal $(,)?
468 ) => {{
469 let starting_schema = Arc::clone($plan.schema());
470 let rule = EliminateCrossJoin::new();
471 let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
472 let formatted_plan = optimized_plan.display_indent_schema();
473 assert!(is_plan_transformed, "failed to optimize plan");
475 assert_eq!(&starting_schema, optimized_plan.schema());
477 assert_snapshot!(
478 formatted_plan,
479 @ $expected,
480 );
481
482 Ok(())
483 }};
484 }
485
486 #[test]
487 fn eliminate_cross_with_simple_and() -> Result<()> {
488 let t1 = test_table_scan_with_name("t1")?;
489 let t2 = test_table_scan_with_name("t2")?;
490
491 let plan = LogicalPlanBuilder::from(t1)
493 .cross_join(t2)?
494 .filter(binary_expr(
495 col("t1.a").eq(col("t2.a")),
496 And,
497 col("t2.c").lt(lit(20u32)),
498 ))?
499 .build()?;
500
501 assert_optimized_plan_equal!(
502 plan,
503 @ r"
504 Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
505 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
506 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
507 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
508 "
509 )
510 }
511
512 #[test]
513 fn eliminate_cross_with_simple_or() -> Result<()> {
514 let t1 = test_table_scan_with_name("t1")?;
515 let t2 = test_table_scan_with_name("t2")?;
516
517 let plan = LogicalPlanBuilder::from(t1)
520 .cross_join(t2)?
521 .filter(binary_expr(
522 col("t1.a").eq(col("t2.a")),
523 Or,
524 col("t2.b").eq(col("t1.a")),
525 ))?
526 .build()?;
527
528 assert_optimized_plan_equal!(
529 plan,
530 @ r"
531 Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
532 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
533 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
534 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
535 "
536 )
537 }
538
539 #[test]
540 fn eliminate_cross_with_and() -> Result<()> {
541 let t1 = test_table_scan_with_name("t1")?;
542 let t2 = test_table_scan_with_name("t2")?;
543
544 let plan = LogicalPlanBuilder::from(t1)
546 .cross_join(t2)?
547 .filter(binary_expr(
548 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
549 And,
550 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
551 ))?
552 .build()?;
553
554 assert_optimized_plan_equal!(
555 plan,
556 @ r"
557 Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
558 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
559 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
560 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
561 "
562 )
563 }
564
565 #[test]
566 fn eliminate_cross_with_or() -> Result<()> {
567 let t1 = test_table_scan_with_name("t1")?;
568 let t2 = test_table_scan_with_name("t2")?;
569
570 let plan = LogicalPlanBuilder::from(t1)
572 .cross_join(t2)?
573 .filter(binary_expr(
574 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
575 Or,
576 binary_expr(
577 col("t1.a").eq(col("t2.a")),
578 And,
579 col("t2.c").eq(lit(688u32)),
580 ),
581 ))?
582 .build()?;
583
584 assert_optimized_plan_equal!(
585 plan,
586 @ r"
587 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
588 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
589 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
590 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
591 "
592 )
593 }
594
595 #[test]
596 fn eliminate_cross_not_possible_simple() -> Result<()> {
597 let t1 = test_table_scan_with_name("t1")?;
598 let t2 = test_table_scan_with_name("t2")?;
599
600 let plan = LogicalPlanBuilder::from(t1)
602 .cross_join(t2)?
603 .filter(binary_expr(
604 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
605 Or,
606 binary_expr(
607 col("t1.b").eq(col("t2.b")),
608 And,
609 col("t2.c").eq(lit(688u32)),
610 ),
611 ))?
612 .build()?;
613
614 assert_optimized_plan_equal!(
615 plan,
616 @ r"
617 Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
618 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
619 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
620 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
621 "
622 )
623 }
624
625 #[test]
626 fn eliminate_cross_not_possible() -> Result<()> {
627 let t1 = test_table_scan_with_name("t1")?;
628 let t2 = test_table_scan_with_name("t2")?;
629
630 let plan = LogicalPlanBuilder::from(t1)
632 .cross_join(t2)?
633 .filter(binary_expr(
634 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
635 Or,
636 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
637 ))?
638 .build()?;
639
640 assert_optimized_plan_equal!(
641 plan,
642 @ r"
643 Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
644 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
645 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
646 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
647 "
648 )
649 }
650
651 #[test]
652 fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
653 let t1 = test_table_scan_with_name("t1")?;
654 let t2 = test_table_scan_with_name("t2")?;
655 let t3 = test_table_scan_with_name("t3")?;
656
657 let plan = LogicalPlanBuilder::from(t1)
659 .join(
660 t3,
661 JoinType::Inner,
662 (vec!["t1.a"], vec!["t3.a"]),
663 Some(col("t1.a").gt(lit(20u32))),
664 )?
665 .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
666 .filter(col("t1.a").gt(lit(15u32)))?
667 .build()?;
668
669 assert_optimized_plan_equal!(
670 plan,
671 @ r"
672 Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
673 Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
674 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
675 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
676 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
677 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
678 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
679 "
680 )
681 }
682
683 #[test]
684 fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
700 let t1 = test_table_scan_with_name("t1")?;
701 let t2 = test_table_scan_with_name("t2")?;
702 let t3 = test_table_scan_with_name("t3")?;
703
704 let plan = LogicalPlanBuilder::from(t1)
706 .cross_join(t2)?
707 .cross_join(t3)?
708 .filter(binary_expr(
709 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
710 And,
711 binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
712 ))?
713 .build()?;
714
715 assert_optimized_plan_equal!(
716 plan,
717 @ r"
718 Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
719 Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
720 Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
721 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
722 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
723 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
724 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
725 "
726 )
727 }
728
729 #[test]
730 fn eliminate_cross_join_multi_tables() -> Result<()> {
731 let t1 = test_table_scan_with_name("t1")?;
732 let t2 = test_table_scan_with_name("t2")?;
733 let t3 = test_table_scan_with_name("t3")?;
734 let t4 = test_table_scan_with_name("t4")?;
735
736 let plan1 = LogicalPlanBuilder::from(t1)
738 .cross_join(t2)?
739 .filter(binary_expr(
740 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
741 Or,
742 binary_expr(
743 col("t1.a").eq(col("t2.a")),
744 And,
745 col("t2.c").eq(lit(688u32)),
746 ),
747 ))?
748 .build()?;
749
750 let plan2 = LogicalPlanBuilder::from(t3)
751 .cross_join(t4)?
752 .filter(binary_expr(
753 binary_expr(
754 binary_expr(
755 col("t3.a").eq(col("t4.a")),
756 And,
757 col("t4.c").lt(lit(15u32)),
758 ),
759 Or,
760 binary_expr(
761 col("t3.a").eq(col("t4.a")),
762 And,
763 col("t3.c").eq(lit(688u32)),
764 ),
765 ),
766 Or,
767 binary_expr(
768 col("t3.a").eq(col("t4.a")),
769 And,
770 col("t3.b").eq(col("t4.b")),
771 ),
772 ))?
773 .build()?;
774
775 let plan = LogicalPlanBuilder::from(plan1)
776 .cross_join(plan2)?
777 .filter(binary_expr(
778 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
779 Or,
780 binary_expr(
781 col("t3.a").eq(col("t1.a")),
782 And,
783 col("t4.c").eq(lit(688u32)),
784 ),
785 ))?
786 .build()?;
787
788 assert_optimized_plan_equal!(
789 plan,
790 @ r"
791 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
792 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
793 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
794 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
795 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
796 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
797 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
798 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
799 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
800 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
801 "
802 )
803 }
804
805 #[test]
806 fn eliminate_cross_join_multi_tables_1() -> Result<()> {
807 let t1 = test_table_scan_with_name("t1")?;
808 let t2 = test_table_scan_with_name("t2")?;
809 let t3 = test_table_scan_with_name("t3")?;
810 let t4 = test_table_scan_with_name("t4")?;
811
812 let plan1 = LogicalPlanBuilder::from(t1)
814 .cross_join(t2)?
815 .filter(binary_expr(
816 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
817 Or,
818 binary_expr(
819 col("t1.a").eq(col("t2.a")),
820 And,
821 col("t2.c").eq(lit(688u32)),
822 ),
823 ))?
824 .build()?;
825
826 let plan2 = LogicalPlanBuilder::from(t3)
828 .cross_join(t4)?
829 .filter(binary_expr(
830 binary_expr(
831 binary_expr(
832 col("t3.a").eq(col("t4.a")),
833 And,
834 col("t4.c").lt(lit(15u32)),
835 ),
836 Or,
837 binary_expr(
838 col("t3.a").eq(col("t4.a")),
839 And,
840 col("t3.c").eq(lit(688u32)),
841 ),
842 ),
843 Or,
844 binary_expr(
845 col("t3.a").eq(col("t4.a")),
846 And,
847 col("t3.b").eq(col("t4.b")),
848 ),
849 ))?
850 .build()?;
851
852 let plan = LogicalPlanBuilder::from(plan1)
854 .cross_join(plan2)?
855 .filter(binary_expr(
856 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
857 Or,
858 binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
859 ))?
860 .build()?;
861
862 assert_optimized_plan_equal!(
863 plan,
864 @ r"
865 Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
866 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
867 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
868 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
869 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
870 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
871 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
872 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
873 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
874 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
875 "
876 )
877 }
878
879 #[test]
880 fn eliminate_cross_join_multi_tables_2() -> Result<()> {
881 let t1 = test_table_scan_with_name("t1")?;
882 let t2 = test_table_scan_with_name("t2")?;
883 let t3 = test_table_scan_with_name("t3")?;
884 let t4 = test_table_scan_with_name("t4")?;
885
886 let plan1 = LogicalPlanBuilder::from(t1)
888 .cross_join(t2)?
889 .filter(binary_expr(
890 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
891 Or,
892 binary_expr(
893 col("t1.a").eq(col("t2.a")),
894 And,
895 col("t2.c").eq(lit(688u32)),
896 ),
897 ))?
898 .build()?;
899
900 let plan2 = LogicalPlanBuilder::from(t3)
902 .cross_join(t4)?
903 .filter(binary_expr(
904 binary_expr(
905 binary_expr(
906 col("t3.a").eq(col("t4.a")),
907 And,
908 col("t4.c").lt(lit(15u32)),
909 ),
910 Or,
911 binary_expr(
912 col("t3.a").eq(col("t4.a")),
913 And,
914 col("t3.c").eq(lit(688u32)),
915 ),
916 ),
917 Or,
918 binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
919 ))?
920 .build()?;
921
922 let plan = LogicalPlanBuilder::from(plan1)
924 .cross_join(plan2)?
925 .filter(binary_expr(
926 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
927 Or,
928 binary_expr(
929 col("t3.a").eq(col("t1.a")),
930 And,
931 col("t4.c").eq(lit(688u32)),
932 ),
933 ))?
934 .build()?;
935
936 assert_optimized_plan_equal!(
937 plan,
938 @ r"
939 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
940 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
941 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
942 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
943 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
944 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
945 Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
946 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
947 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
948 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
949 "
950 )
951 }
952
953 #[test]
954 fn eliminate_cross_join_multi_tables_3() -> Result<()> {
955 let t1 = test_table_scan_with_name("t1")?;
956 let t2 = test_table_scan_with_name("t2")?;
957 let t3 = test_table_scan_with_name("t3")?;
958 let t4 = test_table_scan_with_name("t4")?;
959
960 let plan1 = LogicalPlanBuilder::from(t1)
962 .cross_join(t2)?
963 .filter(binary_expr(
964 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
965 Or,
966 binary_expr(
967 col("t1.a").eq(col("t2.a")),
968 And,
969 col("t2.c").eq(lit(688u32)),
970 ),
971 ))?
972 .build()?;
973
974 let plan2 = LogicalPlanBuilder::from(t3)
976 .cross_join(t4)?
977 .filter(binary_expr(
978 binary_expr(
979 binary_expr(
980 col("t3.a").eq(col("t4.a")),
981 And,
982 col("t4.c").lt(lit(15u32)),
983 ),
984 Or,
985 binary_expr(
986 col("t3.a").eq(col("t4.a")),
987 And,
988 col("t3.c").eq(lit(688u32)),
989 ),
990 ),
991 Or,
992 binary_expr(
993 col("t3.a").eq(col("t4.a")),
994 And,
995 col("t3.b").eq(col("t4.b")),
996 ),
997 ))?
998 .build()?;
999
1000 let plan = LogicalPlanBuilder::from(plan1)
1002 .cross_join(plan2)?
1003 .filter(binary_expr(
1004 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
1005 Or,
1006 binary_expr(
1007 col("t3.a").eq(col("t1.a")),
1008 And,
1009 col("t4.c").eq(lit(688u32)),
1010 ),
1011 ))?
1012 .build()?;
1013
1014 assert_optimized_plan_equal!(
1015 plan,
1016 @ r"
1017 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1018 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1019 Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1020 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1021 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1022 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1023 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1024 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1025 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1026 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1027 "
1028 )
1029 }
1030
1031 #[test]
1032 fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1033 let t1 = test_table_scan_with_name("t1")?;
1034 let t2 = test_table_scan_with_name("t2")?;
1035 let t3 = test_table_scan_with_name("t3")?;
1036 let t4 = test_table_scan_with_name("t4")?;
1037
1038 let plan1 = LogicalPlanBuilder::from(t1)
1041 .cross_join(t2)?
1042 .filter(binary_expr(
1043 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1044 And,
1045 binary_expr(
1046 col("t1.a").eq(col("t2.a")),
1047 And,
1048 col("t2.c").eq(lit(688u32)),
1049 ),
1050 ))?
1051 .build()?;
1052
1053 let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1055
1056 let plan = LogicalPlanBuilder::from(plan1)
1062 .cross_join(plan2)?
1063 .filter(binary_expr(
1064 binary_expr(
1065 binary_expr(
1066 col("t3.a").eq(col("t1.a")),
1067 And,
1068 col("t4.c").lt(lit(15u32)),
1069 ),
1070 Or,
1071 binary_expr(
1072 col("t3.a").eq(col("t1.a")),
1073 And,
1074 col("t4.c").eq(lit(688u32)),
1075 ),
1076 ),
1077 And,
1078 binary_expr(
1079 binary_expr(
1080 binary_expr(
1081 col("t3.a").eq(col("t4.a")),
1082 And,
1083 col("t4.c").lt(lit(15u32)),
1084 ),
1085 Or,
1086 binary_expr(
1087 col("t3.a").eq(col("t4.a")),
1088 And,
1089 col("t3.c").eq(lit(688u32)),
1090 ),
1091 ),
1092 Or,
1093 binary_expr(
1094 col("t3.a").eq(col("t4.a")),
1095 And,
1096 col("t3.b").eq(col("t4.b")),
1097 ),
1098 ),
1099 ))?
1100 .build()?;
1101
1102 assert_optimized_plan_equal!(
1103 plan,
1104 @ r"
1105 Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1106 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1107 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1108 Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1109 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1110 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1111 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1112 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1113 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1114 "
1115 )
1116 }
1117
1118 #[test]
1119 fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1120 let t1 = test_table_scan_with_name("t1")?;
1121 let t2 = test_table_scan_with_name("t2")?;
1122 let t3 = test_table_scan_with_name("t3")?;
1123 let t4 = test_table_scan_with_name("t4")?;
1124
1125 let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1127
1128 let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1130
1131 let plan = LogicalPlanBuilder::from(plan1)
1139 .cross_join(plan2)?
1140 .filter(binary_expr(
1141 binary_expr(
1142 binary_expr(
1143 binary_expr(
1144 col("t3.a").eq(col("t1.a")),
1145 And,
1146 col("t4.c").lt(lit(15u32)),
1147 ),
1148 Or,
1149 binary_expr(
1150 col("t3.a").eq(col("t1.a")),
1151 And,
1152 col("t4.c").eq(lit(688u32)),
1153 ),
1154 ),
1155 And,
1156 binary_expr(
1157 binary_expr(
1158 binary_expr(
1159 col("t3.a").eq(col("t4.a")),
1160 And,
1161 col("t4.c").lt(lit(15u32)),
1162 ),
1163 Or,
1164 binary_expr(
1165 col("t3.a").eq(col("t4.a")),
1166 And,
1167 col("t3.c").eq(lit(688u32)),
1168 ),
1169 ),
1170 Or,
1171 binary_expr(
1172 col("t3.a").eq(col("t4.a")),
1173 And,
1174 col("t3.b").eq(col("t4.b")),
1175 ),
1176 ),
1177 ),
1178 And,
1179 binary_expr(
1180 binary_expr(
1181 col("t1.a").eq(col("t2.a")),
1182 Or,
1183 col("t2.c").lt(lit(15u32)),
1184 ),
1185 And,
1186 binary_expr(
1187 col("t1.a").eq(col("t2.a")),
1188 And,
1189 col("t2.c").eq(lit(688u32)),
1190 ),
1191 ),
1192 ))?
1193 .build()?;
1194
1195 assert_optimized_plan_equal!(
1196 plan,
1197 @ r"
1198 Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1199 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1200 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1201 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1202 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1203 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1204 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1205 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1206 "
1207 )
1208 }
1209
1210 #[test]
1211 fn eliminate_cross_join_with_expr_and() -> Result<()> {
1212 let t1 = test_table_scan_with_name("t1")?;
1213 let t2 = test_table_scan_with_name("t2")?;
1214
1215 let plan = LogicalPlanBuilder::from(t1)
1217 .cross_join(t2)?
1218 .filter(binary_expr(
1219 (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1220 And,
1221 col("t2.c").lt(lit(20u32)),
1222 ))?
1223 .build()?;
1224
1225 assert_optimized_plan_equal!(
1226 plan,
1227 @ r"
1228 Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1229 Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1230 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1231 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1232 "
1233 )
1234 }
1235
1236 #[test]
1237 fn eliminate_cross_with_expr_or() -> Result<()> {
1238 let t1 = test_table_scan_with_name("t1")?;
1239 let t2 = test_table_scan_with_name("t2")?;
1240
1241 let plan = LogicalPlanBuilder::from(t1)
1244 .cross_join(t2)?
1245 .filter(binary_expr(
1246 (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1247 Or,
1248 col("t2.b").eq(col("t1.a")),
1249 ))?
1250 .build()?;
1251
1252 assert_optimized_plan_equal!(
1253 plan,
1254 @ r"
1255 Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1256 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1257 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1258 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1259 "
1260 )
1261 }
1262
1263 #[test]
1264 fn eliminate_cross_with_common_expr_and() -> Result<()> {
1265 let t1 = test_table_scan_with_name("t1")?;
1266 let t2 = test_table_scan_with_name("t2")?;
1267
1268 let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1270 let plan = LogicalPlanBuilder::from(t1)
1271 .cross_join(t2)?
1272 .filter(binary_expr(
1273 binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
1274 And,
1275 binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
1276 ))?
1277 .build()?;
1278
1279 assert_optimized_plan_equal!(
1280 plan,
1281 @ r"
1282 Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1283 Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1284 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1285 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1286 "
1287 )
1288 }
1289
1290 #[test]
1291 fn eliminate_cross_with_common_expr_or() -> Result<()> {
1292 let t1 = test_table_scan_with_name("t1")?;
1293 let t2 = test_table_scan_with_name("t2")?;
1294
1295 let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1297 let plan = LogicalPlanBuilder::from(t1)
1298 .cross_join(t2)?
1299 .filter(binary_expr(
1300 binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
1301 Or,
1302 binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
1303 ))?
1304 .build()?;
1305
1306 assert_optimized_plan_equal!(
1307 plan,
1308 @ r"
1309 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1310 Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1311 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1312 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1313 "
1314 )
1315 }
1316
1317 #[test]
1318 fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
1319 let t1 = test_table_scan_with_name("t1")?;
1320 let t2 = test_table_scan_with_name("t2")?;
1321 let t3 = test_table_scan_with_name("t3")?;
1322
1323 let plan = LogicalPlanBuilder::from(t1)
1325 .cross_join(t2)?
1326 .cross_join(t3)?
1327 .filter(binary_expr(
1328 binary_expr(
1329 (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
1330 And,
1331 col("t3.c").lt(lit(15u32)),
1332 ),
1333 And,
1334 binary_expr(
1335 (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1336 And,
1337 col("t3.b").lt(lit(15u32)),
1338 ),
1339 ))?
1340 .build()?;
1341
1342 assert_optimized_plan_equal!(
1343 plan,
1344 @ r"
1345 Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1346 Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1347 Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1348 Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1349 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1350 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1351 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1352 "
1353 )
1354 }
1355
1356 #[test]
1357 fn preserve_null_equality_setting() -> Result<()> {
1358 let t1 = test_table_scan_with_name("t1")?;
1359 let t2 = test_table_scan_with_name("t2")?;
1360
1361 let join_schema = Arc::new(build_join_schema(
1363 t1.schema(),
1364 t2.schema(),
1365 &JoinType::Inner,
1366 )?);
1367
1368 let inner_join = LogicalPlan::Join(Join {
1369 left: Arc::new(t1),
1370 right: Arc::new(t2),
1371 join_type: JoinType::Inner,
1372 join_constraint: JoinConstraint::On,
1373 on: vec![],
1374 filter: None,
1375 schema: join_schema,
1376 null_equality: NullEquality::NullEqualsNull, null_aware: false,
1378 });
1379
1380 let plan = LogicalPlanBuilder::from(inner_join)
1382 .filter(binary_expr(
1383 col("t1.a").eq(col("t2.a")),
1384 And,
1385 col("t2.c").lt(lit(20u32)),
1386 ))?
1387 .build()?;
1388
1389 let rule = EliminateCrossJoin::new();
1390 let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
1391
1392 fn check_null_equality_preserved(plan: &LogicalPlan) -> bool {
1394 match plan {
1395 LogicalPlan::Join(join) => {
1396 if join.null_equality == NullEquality::NullEqualsNothing {
1398 return false;
1399 }
1400 plan.inputs()
1402 .iter()
1403 .all(|input| check_null_equality_preserved(input))
1404 }
1405 _ => {
1406 plan.inputs()
1408 .iter()
1409 .all(|input| check_null_equality_preserved(input))
1410 }
1411 }
1412 }
1413
1414 assert!(
1415 check_null_equality_preserved(&optimized_plan),
1416 "null_equality setting should be preserved after optimization"
1417 );
1418
1419 Ok(())
1420 }
1421}