1use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27 Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema};
31
/// Optimizer rule that replaces cross joins with inner joins whenever the
/// predicates of a surrounding filter (or of the flattened inner joins
/// themselves) provide usable equi-join keys.
///
/// The rule is stateless; see its [`OptimizerRule`] implementation for the
/// rewrite itself.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new instance of the rule.
    ///
    /// Equivalent to [`EliminateCrossJoin::default`].
    pub fn new() -> Self {
        Self
    }
}
41
42impl OptimizerRule for EliminateCrossJoin {
78 fn supports_rewrite(&self) -> bool {
79 true
80 }
81
82 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83 fn rewrite(
84 &self,
85 plan: LogicalPlan,
86 config: &dyn OptimizerConfig,
87 ) -> Result<Transformed<LogicalPlan>> {
88 let plan_schema = Arc::clone(plan.schema());
89 let mut possible_join_keys = JoinKeySet::new();
90 let mut all_inputs: Vec<LogicalPlan> = vec![];
91 let mut all_filters: Vec<Expr> = vec![];
92 let mut null_equality = NullEquality::NullEqualsNothing;
93
94 let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
95 let rewritable = matches!(
98 filter.input.as_ref(),
99 LogicalPlan::Join(Join {
100 join_type: JoinType::Inner,
101 ..
102 })
103 );
104
105 if !rewritable {
106 return rewrite_children(self, LogicalPlan::Filter(filter), config);
108 }
109
110 if !can_flatten_join_inputs(&filter.input) {
111 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
112 }
113
114 let Filter {
115 input, predicate, ..
116 } = filter;
117
118 if let LogicalPlan::Join(join) = input.as_ref() {
120 null_equality = join.null_equality;
121 }
122
123 flatten_join_inputs(
124 Arc::unwrap_or_clone(input),
125 &mut possible_join_keys,
126 &mut all_inputs,
127 &mut all_filters,
128 )?;
129
130 extract_possible_join_keys(&predicate, &mut possible_join_keys);
131 Some(predicate)
132 } else {
133 match plan {
134 LogicalPlan::Join(Join {
135 join_type: JoinType::Inner,
136 null_equality: original_null_equality,
137 ..
138 }) => {
139 if !can_flatten_join_inputs(&plan) {
140 return Ok(Transformed::no(plan));
141 }
142 flatten_join_inputs(
143 plan,
144 &mut possible_join_keys,
145 &mut all_inputs,
146 &mut all_filters,
147 )?;
148 null_equality = original_null_equality;
149 None
150 }
151 _ => {
152 return rewrite_children(self, plan, config);
154 }
155 }
156 };
157
158 let mut all_join_keys = JoinKeySet::new();
160 let mut left = all_inputs.remove(0);
161 while !all_inputs.is_empty() {
162 left = find_inner_join(
163 left,
164 &mut all_inputs,
165 &possible_join_keys,
166 &mut all_join_keys,
167 null_equality,
168 )?;
169 }
170
171 left = rewrite_children(self, left, config)?.data;
172
173 if &plan_schema != left.schema() {
174 left = LogicalPlan::Projection(Projection::new_from_schema(
175 Arc::new(left),
176 Arc::clone(&plan_schema),
177 ));
178 }
179
180 if !all_filters.is_empty() {
181 let first = all_filters.swap_remove(0);
183 let predicate = all_filters.into_iter().fold(first, and);
184 left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
185 }
186
187 let Some(predicate) = parent_predicate else {
188 return Ok(Transformed::yes(left));
189 };
190
191 if all_join_keys.is_empty() {
193 Filter::try_new(predicate, Arc::new(left))
194 .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
195 } else {
196 match remove_join_expressions(predicate, &all_join_keys) {
198 Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
199 .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
200 _ => Ok(Transformed::yes(left)),
201 }
202 }
203 }
204
205 fn name(&self) -> &str {
206 "eliminate_cross_join"
207 }
208}
209
210fn rewrite_children(
211 optimizer: &impl OptimizerRule,
212 plan: LogicalPlan,
213 config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215 let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
216
217 if transformed_plan.transformed {
219 transformed_plan.map_data(|plan| plan.recompute_schema())
220 } else {
221 Ok(transformed_plan)
222 }
223}
224
225fn flatten_join_inputs(
232 plan: LogicalPlan,
233 possible_join_keys: &mut JoinKeySet,
234 all_inputs: &mut Vec<LogicalPlan>,
235 all_filters: &mut Vec<Expr>,
236) -> Result<()> {
237 match plan {
238 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
239 if let Some(filter) = join.filter {
240 all_filters.push(filter);
241 }
242 possible_join_keys.insert_all_owned(join.on);
243 flatten_join_inputs(
244 Arc::unwrap_or_clone(join.left),
245 possible_join_keys,
246 all_inputs,
247 all_filters,
248 )?;
249 flatten_join_inputs(
250 Arc::unwrap_or_clone(join.right),
251 possible_join_keys,
252 all_inputs,
253 all_filters,
254 )?;
255 }
256 _ => {
257 all_inputs.push(plan);
258 }
259 };
260 Ok(())
261}
262
263fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
268 match plan {
270 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
271 _ => return false,
272 };
273
274 for child in plan.inputs() {
275 if let LogicalPlan::Join(Join {
276 join_type: JoinType::Inner,
277 ..
278 }) = child
279 && !can_flatten_join_inputs(child)
280 {
281 return false;
282 }
283 }
284 true
285}
286
287fn find_inner_join(
300 left_input: LogicalPlan,
301 rights: &mut Vec<LogicalPlan>,
302 possible_join_keys: &JoinKeySet,
303 all_join_keys: &mut JoinKeySet,
304 null_equality: NullEquality,
305) -> Result<LogicalPlan> {
306 for (i, right_input) in rights.iter().enumerate() {
307 let mut join_keys = vec![];
308
309 for (l, r) in possible_join_keys.iter() {
310 let key_pair = find_valid_equijoin_key_pair(
311 l,
312 r,
313 left_input.schema(),
314 right_input.schema(),
315 )?;
316
317 if let Some((valid_l, valid_r)) = key_pair
319 && can_hash(&valid_l.get_type(left_input.schema())?)
320 {
321 join_keys.push((valid_l, valid_r));
322 }
323 }
324
325 if !join_keys.is_empty() {
327 all_join_keys.insert_all(join_keys.iter());
328 let right_input = rights.remove(i);
329 let join_schema = Arc::new(build_join_schema(
330 left_input.schema(),
331 right_input.schema(),
332 &JoinType::Inner,
333 )?);
334
335 return Ok(LogicalPlan::Join(Join {
336 left: Arc::new(left_input),
337 right: Arc::new(right_input),
338 join_type: JoinType::Inner,
339 join_constraint: JoinConstraint::On,
340 on: join_keys,
341 filter: None,
342 schema: join_schema,
343 null_equality,
344 null_aware: false,
345 }));
346 }
347 }
348
349 let right = rights.remove(0);
352 let join_schema = Arc::new(build_join_schema(
353 left_input.schema(),
354 right.schema(),
355 &JoinType::Inner,
356 )?);
357
358 Ok(LogicalPlan::Join(Join {
359 left: Arc::new(left_input),
360 right: Arc::new(right),
361 schema: join_schema,
362 on: vec![],
363 filter: None,
364 join_type: JoinType::Inner,
365 join_constraint: JoinConstraint::On,
366 null_equality,
367 null_aware: false,
368 }))
369}
370
371fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
373 if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
374 match op {
375 Operator::Eq => {
376 join_keys.insert(left, right);
378 }
379 Operator::And => {
380 extract_possible_join_keys(left, join_keys);
381 extract_possible_join_keys(right, join_keys)
382 }
383 Operator::Or => {
385 let mut left_join_keys = JoinKeySet::new();
386 let mut right_join_keys = JoinKeySet::new();
387
388 extract_possible_join_keys(left, &mut left_join_keys);
389 extract_possible_join_keys(right, &mut right_join_keys);
390
391 join_keys.insert_intersection(&left_join_keys, &right_join_keys)
392 }
393 _ => (),
394 };
395 }
396}
397
398fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
404 match expr {
405 Expr::BinaryExpr(BinaryExpr {
406 left,
407 op: Operator::Eq,
408 right,
409 }) if join_keys.contains(&left, &right) => {
410 None
412 }
413 Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
415 let l = remove_join_expressions(*left, join_keys);
416 let r = remove_join_expressions(*right, join_keys);
417 match (l, r) {
418 (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
419 Box::new(ll),
420 op,
421 Box::new(rr),
422 ))),
423 (Some(ll), _) => Some(ll),
424 (_, Some(rr)) => Some(rr),
425 _ => None,
426 }
427 }
428 Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
429 let l = remove_join_expressions(*left, join_keys);
430 let r = remove_join_expressions(*right, join_keys);
431 match (l, r) {
432 (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
433 Box::new(ll),
434 op,
435 Box::new(rr),
436 ))),
437 _ => None,
440 }
441 }
442 _ => Some(expr),
443 }
444}
445
446#[cfg(test)]
447mod tests {
448 use super::*;
449 use crate::optimizer::OptimizerContext;
450 use crate::test::*;
451
452 use datafusion_expr::{
453 Operator::{And, Or},
454 binary_expr, col, lit,
455 logical_plan::builder::LogicalPlanBuilder,
456 };
457 use insta::assert_snapshot;
458
459 macro_rules! assert_optimized_plan_equal {
460 (
461 $plan:expr,
462 @ $expected:literal $(,)?
463 ) => {{
464 let starting_schema = Arc::clone($plan.schema());
465 let rule = EliminateCrossJoin::new();
466 let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
467 let formatted_plan = optimized_plan.display_indent_schema();
468 assert!(is_plan_transformed, "failed to optimize plan");
470 assert_eq!(&starting_schema, optimized_plan.schema());
472 assert_snapshot!(
473 formatted_plan,
474 @ $expected,
475 );
476
477 Ok(())
478 }};
479 }
480
481 #[test]
482 fn eliminate_cross_with_simple_and() -> Result<()> {
483 let t1 = test_table_scan_with_name("t1")?;
484 let t2 = test_table_scan_with_name("t2")?;
485
486 let plan = LogicalPlanBuilder::from(t1)
488 .cross_join(t2)?
489 .filter(binary_expr(
490 col("t1.a").eq(col("t2.a")),
491 And,
492 col("t2.c").lt(lit(20u32)),
493 ))?
494 .build()?;
495
496 assert_optimized_plan_equal!(
497 plan,
498 @ r"
499 Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
500 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
501 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
502 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
503 "
504 )
505 }
506
507 #[test]
508 fn eliminate_cross_with_simple_or() -> Result<()> {
509 let t1 = test_table_scan_with_name("t1")?;
510 let t2 = test_table_scan_with_name("t2")?;
511
512 let plan = LogicalPlanBuilder::from(t1)
515 .cross_join(t2)?
516 .filter(binary_expr(
517 col("t1.a").eq(col("t2.a")),
518 Or,
519 col("t2.b").eq(col("t1.a")),
520 ))?
521 .build()?;
522
523 assert_optimized_plan_equal!(
524 plan,
525 @ r"
526 Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
527 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
528 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
529 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
530 "
531 )
532 }
533
534 #[test]
535 fn eliminate_cross_with_and() -> Result<()> {
536 let t1 = test_table_scan_with_name("t1")?;
537 let t2 = test_table_scan_with_name("t2")?;
538
539 let plan = LogicalPlanBuilder::from(t1)
541 .cross_join(t2)?
542 .filter(binary_expr(
543 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
544 And,
545 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
546 ))?
547 .build()?;
548
549 assert_optimized_plan_equal!(
550 plan,
551 @ r"
552 Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
553 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
554 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
555 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
556 "
557 )
558 }
559
560 #[test]
561 fn eliminate_cross_with_or() -> Result<()> {
562 let t1 = test_table_scan_with_name("t1")?;
563 let t2 = test_table_scan_with_name("t2")?;
564
565 let plan = LogicalPlanBuilder::from(t1)
567 .cross_join(t2)?
568 .filter(binary_expr(
569 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
570 Or,
571 binary_expr(
572 col("t1.a").eq(col("t2.a")),
573 And,
574 col("t2.c").eq(lit(688u32)),
575 ),
576 ))?
577 .build()?;
578
579 assert_optimized_plan_equal!(
580 plan,
581 @ r"
582 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
583 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
584 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
585 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
586 "
587 )
588 }
589
590 #[test]
591 fn eliminate_cross_not_possible_simple() -> Result<()> {
592 let t1 = test_table_scan_with_name("t1")?;
593 let t2 = test_table_scan_with_name("t2")?;
594
595 let plan = LogicalPlanBuilder::from(t1)
597 .cross_join(t2)?
598 .filter(binary_expr(
599 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
600 Or,
601 binary_expr(
602 col("t1.b").eq(col("t2.b")),
603 And,
604 col("t2.c").eq(lit(688u32)),
605 ),
606 ))?
607 .build()?;
608
609 assert_optimized_plan_equal!(
610 plan,
611 @ r"
612 Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
613 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
614 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
615 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
616 "
617 )
618 }
619
620 #[test]
621 fn eliminate_cross_not_possible() -> Result<()> {
622 let t1 = test_table_scan_with_name("t1")?;
623 let t2 = test_table_scan_with_name("t2")?;
624
625 let plan = LogicalPlanBuilder::from(t1)
627 .cross_join(t2)?
628 .filter(binary_expr(
629 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
630 Or,
631 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
632 ))?
633 .build()?;
634
635 assert_optimized_plan_equal!(
636 plan,
637 @ r"
638 Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
639 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
640 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
641 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
642 "
643 )
644 }
645
646 #[test]
647 fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
648 let t1 = test_table_scan_with_name("t1")?;
649 let t2 = test_table_scan_with_name("t2")?;
650 let t3 = test_table_scan_with_name("t3")?;
651
652 let plan = LogicalPlanBuilder::from(t1)
654 .join(
655 t3,
656 JoinType::Inner,
657 (vec!["t1.a"], vec!["t3.a"]),
658 Some(col("t1.a").gt(lit(20u32))),
659 )?
660 .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
661 .filter(col("t1.a").gt(lit(15u32)))?
662 .build()?;
663
664 assert_optimized_plan_equal!(
665 plan,
666 @ r"
667 Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
668 Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
669 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
670 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
671 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
672 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
673 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
674 "
675 )
676 }
677
678 #[test]
679 fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
695 let t1 = test_table_scan_with_name("t1")?;
696 let t2 = test_table_scan_with_name("t2")?;
697 let t3 = test_table_scan_with_name("t3")?;
698
699 let plan = LogicalPlanBuilder::from(t1)
701 .cross_join(t2)?
702 .cross_join(t3)?
703 .filter(binary_expr(
704 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
705 And,
706 binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
707 ))?
708 .build()?;
709
710 assert_optimized_plan_equal!(
711 plan,
712 @ r"
713 Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
714 Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
715 Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
716 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
717 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
718 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
719 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
720 "
721 )
722 }
723
724 #[test]
725 fn eliminate_cross_join_multi_tables() -> Result<()> {
726 let t1 = test_table_scan_with_name("t1")?;
727 let t2 = test_table_scan_with_name("t2")?;
728 let t3 = test_table_scan_with_name("t3")?;
729 let t4 = test_table_scan_with_name("t4")?;
730
731 let plan1 = LogicalPlanBuilder::from(t1)
733 .cross_join(t2)?
734 .filter(binary_expr(
735 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
736 Or,
737 binary_expr(
738 col("t1.a").eq(col("t2.a")),
739 And,
740 col("t2.c").eq(lit(688u32)),
741 ),
742 ))?
743 .build()?;
744
745 let plan2 = LogicalPlanBuilder::from(t3)
746 .cross_join(t4)?
747 .filter(binary_expr(
748 binary_expr(
749 binary_expr(
750 col("t3.a").eq(col("t4.a")),
751 And,
752 col("t4.c").lt(lit(15u32)),
753 ),
754 Or,
755 binary_expr(
756 col("t3.a").eq(col("t4.a")),
757 And,
758 col("t3.c").eq(lit(688u32)),
759 ),
760 ),
761 Or,
762 binary_expr(
763 col("t3.a").eq(col("t4.a")),
764 And,
765 col("t3.b").eq(col("t4.b")),
766 ),
767 ))?
768 .build()?;
769
770 let plan = LogicalPlanBuilder::from(plan1)
771 .cross_join(plan2)?
772 .filter(binary_expr(
773 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
774 Or,
775 binary_expr(
776 col("t3.a").eq(col("t1.a")),
777 And,
778 col("t4.c").eq(lit(688u32)),
779 ),
780 ))?
781 .build()?;
782
783 assert_optimized_plan_equal!(
784 plan,
785 @ r"
786 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
787 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
788 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
789 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
790 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
791 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
792 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
793 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
794 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
795 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
796 "
797 )
798 }
799
800 #[test]
801 fn eliminate_cross_join_multi_tables_1() -> Result<()> {
802 let t1 = test_table_scan_with_name("t1")?;
803 let t2 = test_table_scan_with_name("t2")?;
804 let t3 = test_table_scan_with_name("t3")?;
805 let t4 = test_table_scan_with_name("t4")?;
806
807 let plan1 = LogicalPlanBuilder::from(t1)
809 .cross_join(t2)?
810 .filter(binary_expr(
811 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
812 Or,
813 binary_expr(
814 col("t1.a").eq(col("t2.a")),
815 And,
816 col("t2.c").eq(lit(688u32)),
817 ),
818 ))?
819 .build()?;
820
821 let plan2 = LogicalPlanBuilder::from(t3)
823 .cross_join(t4)?
824 .filter(binary_expr(
825 binary_expr(
826 binary_expr(
827 col("t3.a").eq(col("t4.a")),
828 And,
829 col("t4.c").lt(lit(15u32)),
830 ),
831 Or,
832 binary_expr(
833 col("t3.a").eq(col("t4.a")),
834 And,
835 col("t3.c").eq(lit(688u32)),
836 ),
837 ),
838 Or,
839 binary_expr(
840 col("t3.a").eq(col("t4.a")),
841 And,
842 col("t3.b").eq(col("t4.b")),
843 ),
844 ))?
845 .build()?;
846
847 let plan = LogicalPlanBuilder::from(plan1)
849 .cross_join(plan2)?
850 .filter(binary_expr(
851 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
852 Or,
853 binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
854 ))?
855 .build()?;
856
857 assert_optimized_plan_equal!(
858 plan,
859 @ r"
860 Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
861 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
862 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
863 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
864 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
865 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
866 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
867 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
868 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
869 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
870 "
871 )
872 }
873
874 #[test]
875 fn eliminate_cross_join_multi_tables_2() -> Result<()> {
876 let t1 = test_table_scan_with_name("t1")?;
877 let t2 = test_table_scan_with_name("t2")?;
878 let t3 = test_table_scan_with_name("t3")?;
879 let t4 = test_table_scan_with_name("t4")?;
880
881 let plan1 = LogicalPlanBuilder::from(t1)
883 .cross_join(t2)?
884 .filter(binary_expr(
885 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
886 Or,
887 binary_expr(
888 col("t1.a").eq(col("t2.a")),
889 And,
890 col("t2.c").eq(lit(688u32)),
891 ),
892 ))?
893 .build()?;
894
895 let plan2 = LogicalPlanBuilder::from(t3)
897 .cross_join(t4)?
898 .filter(binary_expr(
899 binary_expr(
900 binary_expr(
901 col("t3.a").eq(col("t4.a")),
902 And,
903 col("t4.c").lt(lit(15u32)),
904 ),
905 Or,
906 binary_expr(
907 col("t3.a").eq(col("t4.a")),
908 And,
909 col("t3.c").eq(lit(688u32)),
910 ),
911 ),
912 Or,
913 binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
914 ))?
915 .build()?;
916
917 let plan = LogicalPlanBuilder::from(plan1)
919 .cross_join(plan2)?
920 .filter(binary_expr(
921 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
922 Or,
923 binary_expr(
924 col("t3.a").eq(col("t1.a")),
925 And,
926 col("t4.c").eq(lit(688u32)),
927 ),
928 ))?
929 .build()?;
930
931 assert_optimized_plan_equal!(
932 plan,
933 @ r"
934 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
935 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
936 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
937 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
938 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
939 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
940 Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
941 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
942 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
943 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
944 "
945 )
946 }
947
948 #[test]
949 fn eliminate_cross_join_multi_tables_3() -> Result<()> {
950 let t1 = test_table_scan_with_name("t1")?;
951 let t2 = test_table_scan_with_name("t2")?;
952 let t3 = test_table_scan_with_name("t3")?;
953 let t4 = test_table_scan_with_name("t4")?;
954
955 let plan1 = LogicalPlanBuilder::from(t1)
957 .cross_join(t2)?
958 .filter(binary_expr(
959 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
960 Or,
961 binary_expr(
962 col("t1.a").eq(col("t2.a")),
963 And,
964 col("t2.c").eq(lit(688u32)),
965 ),
966 ))?
967 .build()?;
968
969 let plan2 = LogicalPlanBuilder::from(t3)
971 .cross_join(t4)?
972 .filter(binary_expr(
973 binary_expr(
974 binary_expr(
975 col("t3.a").eq(col("t4.a")),
976 And,
977 col("t4.c").lt(lit(15u32)),
978 ),
979 Or,
980 binary_expr(
981 col("t3.a").eq(col("t4.a")),
982 And,
983 col("t3.c").eq(lit(688u32)),
984 ),
985 ),
986 Or,
987 binary_expr(
988 col("t3.a").eq(col("t4.a")),
989 And,
990 col("t3.b").eq(col("t4.b")),
991 ),
992 ))?
993 .build()?;
994
995 let plan = LogicalPlanBuilder::from(plan1)
997 .cross_join(plan2)?
998 .filter(binary_expr(
999 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
1000 Or,
1001 binary_expr(
1002 col("t3.a").eq(col("t1.a")),
1003 And,
1004 col("t4.c").eq(lit(688u32)),
1005 ),
1006 ))?
1007 .build()?;
1008
1009 assert_optimized_plan_equal!(
1010 plan,
1011 @ r"
1012 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1013 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1014 Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1015 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1016 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1017 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1018 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1019 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1020 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1021 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1022 "
1023 )
1024 }
1025
1026 #[test]
1027 fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1028 let t1 = test_table_scan_with_name("t1")?;
1029 let t2 = test_table_scan_with_name("t2")?;
1030 let t3 = test_table_scan_with_name("t3")?;
1031 let t4 = test_table_scan_with_name("t4")?;
1032
1033 let plan1 = LogicalPlanBuilder::from(t1)
1036 .cross_join(t2)?
1037 .filter(binary_expr(
1038 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1039 And,
1040 binary_expr(
1041 col("t1.a").eq(col("t2.a")),
1042 And,
1043 col("t2.c").eq(lit(688u32)),
1044 ),
1045 ))?
1046 .build()?;
1047
1048 let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1050
1051 let plan = LogicalPlanBuilder::from(plan1)
1057 .cross_join(plan2)?
1058 .filter(binary_expr(
1059 binary_expr(
1060 binary_expr(
1061 col("t3.a").eq(col("t1.a")),
1062 And,
1063 col("t4.c").lt(lit(15u32)),
1064 ),
1065 Or,
1066 binary_expr(
1067 col("t3.a").eq(col("t1.a")),
1068 And,
1069 col("t4.c").eq(lit(688u32)),
1070 ),
1071 ),
1072 And,
1073 binary_expr(
1074 binary_expr(
1075 binary_expr(
1076 col("t3.a").eq(col("t4.a")),
1077 And,
1078 col("t4.c").lt(lit(15u32)),
1079 ),
1080 Or,
1081 binary_expr(
1082 col("t3.a").eq(col("t4.a")),
1083 And,
1084 col("t3.c").eq(lit(688u32)),
1085 ),
1086 ),
1087 Or,
1088 binary_expr(
1089 col("t3.a").eq(col("t4.a")),
1090 And,
1091 col("t3.b").eq(col("t4.b")),
1092 ),
1093 ),
1094 ))?
1095 .build()?;
1096
1097 assert_optimized_plan_equal!(
1098 plan,
1099 @ r"
1100 Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1101 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1102 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1103 Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1104 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1105 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1106 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1107 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1108 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1109 "
1110 )
1111 }
1112
1113 #[test]
1114 fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1115 let t1 = test_table_scan_with_name("t1")?;
1116 let t2 = test_table_scan_with_name("t2")?;
1117 let t3 = test_table_scan_with_name("t3")?;
1118 let t4 = test_table_scan_with_name("t4")?;
1119
1120 let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1122
1123 let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1125
1126 let plan = LogicalPlanBuilder::from(plan1)
1134 .cross_join(plan2)?
1135 .filter(binary_expr(
1136 binary_expr(
1137 binary_expr(
1138 binary_expr(
1139 col("t3.a").eq(col("t1.a")),
1140 And,
1141 col("t4.c").lt(lit(15u32)),
1142 ),
1143 Or,
1144 binary_expr(
1145 col("t3.a").eq(col("t1.a")),
1146 And,
1147 col("t4.c").eq(lit(688u32)),
1148 ),
1149 ),
1150 And,
1151 binary_expr(
1152 binary_expr(
1153 binary_expr(
1154 col("t3.a").eq(col("t4.a")),
1155 And,
1156 col("t4.c").lt(lit(15u32)),
1157 ),
1158 Or,
1159 binary_expr(
1160 col("t3.a").eq(col("t4.a")),
1161 And,
1162 col("t3.c").eq(lit(688u32)),
1163 ),
1164 ),
1165 Or,
1166 binary_expr(
1167 col("t3.a").eq(col("t4.a")),
1168 And,
1169 col("t3.b").eq(col("t4.b")),
1170 ),
1171 ),
1172 ),
1173 And,
1174 binary_expr(
1175 binary_expr(
1176 col("t1.a").eq(col("t2.a")),
1177 Or,
1178 col("t2.c").lt(lit(15u32)),
1179 ),
1180 And,
1181 binary_expr(
1182 col("t1.a").eq(col("t2.a")),
1183 And,
1184 col("t2.c").eq(lit(688u32)),
1185 ),
1186 ),
1187 ))?
1188 .build()?;
1189
1190 assert_optimized_plan_equal!(
1191 plan,
1192 @ r"
1193 Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1194 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1195 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1196 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1197 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1198 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1199 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1200 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1201 "
1202 )
1203 }
1204
/// A cross join filtered by `<expr equality> AND <single-table predicate>`
/// is expected to become an inner join keyed on the expression equality,
/// with only the single-table predicate left in the filter.
#[test]
fn eliminate_cross_join_with_expr_and() -> Result<()> {
    let left = test_table_scan_with_name("t1")?;
    let right = test_table_scan_with_name("t2")?;

    // Equality between computed expressions, one side per input table.
    let expr_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    // Predicate that only touches `t2`; the expected plan keeps it in the filter.
    let residual = col("t2.c").lt(lit(20u32));
    let predicate = binary_expr(expr_join_key, And, residual);

    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1230
/// When the expression equality is only one branch of an `OR`, the expected
/// plan is unchanged: the whole disjunction stays in the filter above the
/// cross join.
#[test]
fn eliminate_cross_with_expr_or() -> Result<()> {
    let left = test_table_scan_with_name("t1")?;
    let right = test_table_scan_with_name("t2")?;

    let expr_equality = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let column_equality = col("t2.b").eq(col("t1.a"));
    // The two equalities differ, so neither is common to both `OR` branches.
    let predicate = binary_expr(expr_equality, Or, column_equality);

    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1257
/// The same expression equality appears in both halves of an `AND`.
/// The expected plan uses it once as the inner-join key and keeps the two
/// remaining `t2.c` predicates in the filter.
#[test]
fn eliminate_cross_with_common_expr_and() -> Result<()> {
    let left = test_table_scan_with_name("t1")?;
    let right = test_table_scan_with_name("t2")?;

    // Equality shared by both conjuncts.
    let shared_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let first_half = binary_expr(shared_key.clone(), And, col("t2.c").lt(lit(20u32)));
    let second_half = binary_expr(shared_key, And, col("t2.c").eq(lit(10u32)));

    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .filter(binary_expr(first_half, And, second_half))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1284
/// The same expression equality is AND-ed into both branches of an `OR`.
/// The expected plan lifts it out as the inner-join key and leaves the
/// disjunction of the two `t2.c` predicates in the filter.
#[test]
fn eliminate_cross_with_common_expr_or() -> Result<()> {
    let left = test_table_scan_with_name("t1")?;
    let right = test_table_scan_with_name("t2")?;

    // Equality present in both `OR` branches.
    let shared_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let first_branch = binary_expr(shared_key.clone(), And, col("t2.c").lt(lit(15u32)));
    let second_branch = binary_expr(shared_key, And, col("t2.c").eq(lit(688u32)));

    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .filter(binary_expr(first_branch, Or, second_branch))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1311
/// Three tables are cross joined as `t1 x t2 x t3`, while the filter holds
/// expression equalities linking `t3` to `t1` and `t3` to `t2`.
/// The expected plan joins `t1` with `t3` first, then `t2`, and adds a
/// projection listing the columns in `t1, t2, t3` order.
#[test]
fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
    let scan_t1 = test_table_scan_with_name("t1")?;
    let scan_t2 = test_table_scan_with_name("t2")?;
    let scan_t3 = test_table_scan_with_name("t3")?;

    // `t3` <-> `t1` equality plus a predicate on `t3` alone.
    let t3_with_t1 = binary_expr(
        (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
        And,
        col("t3.c").lt(lit(15u32)),
    );
    // `t3` <-> `t2` equality plus a predicate on `t3` alone.
    let t3_with_t2 = binary_expr(
        (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
        And,
        col("t3.b").lt(lit(15u32)),
    );

    let plan = LogicalPlanBuilder::from(scan_t1)
        .cross_join(scan_t2)?
        .cross_join(scan_t3)?
        .filter(binary_expr(t3_with_t1, And, t3_with_t2))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1350
/// Builds an inner join by hand with no `on` keys and
/// `NullEquality::NullEqualsNull`, puts a filter holding `t1.a = t2.a` on
/// top, runs the rule directly, and asserts that no join in the rewritten
/// plan has `NullEquality::NullEqualsNothing`.
#[test]
fn preserve_null_equality_setting() -> Result<()> {
    let left_scan = test_table_scan_with_name("t1")?;
    let right_scan = test_table_scan_with_name("t2")?;

    let schema = Arc::new(build_join_schema(
        left_scan.schema(),
        right_scan.schema(),
        &JoinType::Inner,
    )?);

    // Constructed as a struct literal so that `null_equality` can be set
    // explicitly on a join that has no equijoin keys.
    let keyless_join = LogicalPlan::Join(Join {
        left: Arc::new(left_scan),
        right: Arc::new(right_scan),
        join_type: JoinType::Inner,
        join_constraint: JoinConstraint::On,
        on: vec![],
        filter: None,
        schema,
        null_equality: NullEquality::NullEqualsNull,
        null_aware: false,
    });

    let predicate = binary_expr(
        col("t1.a").eq(col("t2.a")),
        And,
        col("t2.c").lt(lit(20u32)),
    );
    let plan = LogicalPlanBuilder::from(keyless_join)
        .filter(predicate)?
        .build()?;

    let rule = EliminateCrossJoin::new();
    let transformed = rule.rewrite(plan, &OptimizerContext::new())?;
    let optimized_plan = transformed.data;

    /// Returns `false` as soon as a join with `NullEqualsNothing` is found
    /// anywhere in `plan`; otherwise returns `true`.
    fn check_null_equality_preserved(plan: &LogicalPlan) -> bool {
        let resets_null_equality = matches!(
            plan,
            LogicalPlan::Join(join)
                if join.null_equality == NullEquality::NullEqualsNothing
        );
        // `&&` short-circuits, so children of an offending join are not visited.
        !resets_null_equality
            && plan
                .inputs()
                .iter()
                .all(|child| check_null_equality_preserved(child))
    }

    assert!(
        check_null_equality_preserved(&optimized_plan),
        "null_equality setting should be preserved after optimization"
    );

    Ok(())
}
1416}