1use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27 Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema};
31
/// Optimizer rule that turns cross joins (inner joins without equi-join keys)
/// into inner joins with keys, using equality predicates found in a parent
/// `Filter` or in the `on` lists of the surrounding inner-join tree.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new [`EliminateCrossJoin`] rule.
    pub fn new() -> Self {
        Self
    }
}
41
impl OptimizerRule for EliminateCrossJoin {
    /// Signals that this rule provides an implementation of
    /// [`OptimizerRule::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Converts cross joins (inner joins with no equi-join keys) into inner
    /// joins with keys wherever equality predicates connect their inputs.
    ///
    /// Two root shapes are rewritten:
    ///
    /// * `Filter` directly over an inner `Join`: the join tree below the filter
    ///   is flattened into its non-inner-join inputs. Candidate key pairs are
    ///   collected from the joins' `on` lists and from the filter predicate.
    ///   The inputs are then re-joined greedily, and equality predicates that
    ///   became join keys are removed from the filter. The filter is dropped
    ///   entirely if nothing remains.
    /// * A bare inner `Join`: the same, using only the keys already present in
    ///   the join tree.
    ///
    /// For any other node, the rule is applied to the node's children instead.
    ///
    /// Residual `Join::filter` expressions from the flattened joins are AND-ed
    /// into a `Filter` placed above the rebuilt joins. A `Projection` is
    /// inserted when the rebuilt plan's schema differs from the original one
    /// (for example because inputs were reordered).
    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
    fn rewrite(
        &self,
        plan: LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // Captured before `plan` is consumed so the original output schema can
        // be restored after the inputs have been re-joined in a new order.
        let plan_schema = Arc::clone(plan.schema());
        let mut possible_join_keys = JoinKeySet::new();
        let mut all_inputs: Vec<LogicalPlan> = vec![];
        let mut all_filters: Vec<Expr> = vec![];
        // Overwritten below with the setting of the root join being rewritten.
        let mut null_equality = NullEquality::NullEqualsNothing;

        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
            // Only a filter sitting directly on an inner join is handled here;
            // checking first avoids unwrapping the input unnecessarily.
            let rewritable = matches!(
                filter.input.as_ref(),
                LogicalPlan::Join(Join {
                    join_type: JoinType::Inner,
                    ..
                })
            );

            if !rewritable {
                // Not our shape: try the rule on the children instead.
                return rewrite_children(self, LogicalPlan::Filter(filter), config);
            }

            if !can_flatten_join_inputs(&filter.input) {
                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
            }

            let Filter {
                input, predicate, ..
            } = filter;

            // NOTE(review): only the root join's `null_equality` is kept; nested
            // inner joins flattened below are rebuilt with this same value.
            if let LogicalPlan::Join(join) = input.as_ref() {
                null_equality = join.null_equality;
            }

            flatten_join_inputs(
                Arc::unwrap_or_clone(input),
                &mut possible_join_keys,
                &mut all_inputs,
                &mut all_filters,
            )?;

            // Equalities in the filter predicate are additional key candidates.
            extract_possible_join_keys(&predicate, &mut possible_join_keys);
            Some(predicate)
        } else {
            match plan {
                LogicalPlan::Join(Join {
                    join_type: JoinType::Inner,
                    null_equality: original_null_equality,
                    ..
                }) => {
                    if !can_flatten_join_inputs(&plan) {
                        return Ok(Transformed::no(plan));
                    }
                    flatten_join_inputs(
                        plan,
                        &mut possible_join_keys,
                        &mut all_inputs,
                        &mut all_filters,
                    )?;
                    null_equality = original_null_equality;
                    // No parent filter: there is no predicate to prune later.
                    None
                }
                _ => {
                    // Neither a filtered join nor an inner join: recurse.
                    return rewrite_children(self, plan, config);
                }
            }
        };

        // Keys actually used by the rebuilt joins; used below to prune the
        // parent predicate.
        let mut all_join_keys = JoinKeySet::new();
        // Greedily fold the flattened inputs into a left-deep join tree. Each
        // step consumes one entry of `all_inputs`, so the loop terminates.
        let mut left = all_inputs.remove(0);
        while !all_inputs.is_empty() {
            left = find_inner_join(
                left,
                &mut all_inputs,
                &possible_join_keys,
                &mut all_join_keys,
                null_equality,
            )?;
        }

        // Apply the rule below the rebuilt root join. The `transformed` flag is
        // discarded because every path below reports `Transformed::yes`.
        left = rewrite_children(self, left, config)?.data;

        // Restore the original schema (column order) if re-joining changed it.
        if &plan_schema != left.schema() {
            left = LogicalPlan::Projection(Projection::new_from_schema(
                Arc::new(left),
                Arc::clone(&plan_schema),
            ));
        }

        if !all_filters.is_empty() {
            // Re-apply residual join filters as one conjunction above the joins.
            // `swap_remove` reorders the remaining conjuncts, which is harmless
            // for AND.
            let first = all_filters.swap_remove(0);
            let predicate = all_filters.into_iter().fold(first, and);
            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
        }

        let Some(predicate) = parent_predicate else {
            return Ok(Transformed::yes(left));
        };

        if all_join_keys.is_empty() {
            // No keys were extracted, so the parent predicate is kept verbatim.
            Filter::try_new(predicate, Arc::new(left))
                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
        } else {
            // Strip the equalities that became join keys. `None` means the whole
            // predicate was consumed and the filter can be dropped.
            match remove_join_expressions(predicate, &all_join_keys) {
                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
                _ => Ok(Transformed::yes(left)),
            }
        }
    }

    /// Name under which this rule is registered.
    fn name(&self) -> &str {
        "eliminate_cross_join"
    }
}
209
210fn rewrite_children(
211 optimizer: &impl OptimizerRule,
212 plan: LogicalPlan,
213 config: &dyn OptimizerConfig,
214) -> Result<Transformed<LogicalPlan>> {
215 let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
216
217 if transformed_plan.transformed {
219 transformed_plan.map_data(|plan| plan.recompute_schema())
220 } else {
221 Ok(transformed_plan)
222 }
223}
224
225fn flatten_join_inputs(
232 plan: LogicalPlan,
233 possible_join_keys: &mut JoinKeySet,
234 all_inputs: &mut Vec<LogicalPlan>,
235 all_filters: &mut Vec<Expr>,
236) -> Result<()> {
237 match plan {
238 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
239 if let Some(filter) = join.filter {
240 all_filters.push(filter);
241 }
242 possible_join_keys.insert_all_owned(join.on);
243 flatten_join_inputs(
244 Arc::unwrap_or_clone(join.left),
245 possible_join_keys,
246 all_inputs,
247 all_filters,
248 )?;
249 flatten_join_inputs(
250 Arc::unwrap_or_clone(join.right),
251 possible_join_keys,
252 all_inputs,
253 all_filters,
254 )?;
255 }
256 _ => {
257 all_inputs.push(plan);
258 }
259 };
260 Ok(())
261}
262
263fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
268 match plan {
270 LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
271 _ => return false,
272 };
273
274 for child in plan.inputs() {
275 if let LogicalPlan::Join(Join {
276 join_type: JoinType::Inner,
277 ..
278 }) = child
279 && !can_flatten_join_inputs(child)
280 {
281 return false;
282 }
283 }
284 true
285}
286
287fn find_inner_join(
300 left_input: LogicalPlan,
301 rights: &mut Vec<LogicalPlan>,
302 possible_join_keys: &JoinKeySet,
303 all_join_keys: &mut JoinKeySet,
304 null_equality: NullEquality,
305) -> Result<LogicalPlan> {
306 for (i, right_input) in rights.iter().enumerate() {
307 let mut join_keys = vec![];
308
309 for (l, r) in possible_join_keys.iter() {
310 let key_pair = find_valid_equijoin_key_pair(
311 l,
312 r,
313 left_input.schema(),
314 right_input.schema(),
315 )?;
316
317 if let Some((valid_l, valid_r)) = key_pair
319 && can_hash(&valid_l.get_type(left_input.schema())?)
320 {
321 join_keys.push((valid_l, valid_r));
322 }
323 }
324
325 if !join_keys.is_empty() {
327 all_join_keys.insert_all(join_keys.iter());
328 let right_input = rights.remove(i);
329 let join_schema = Arc::new(build_join_schema(
330 left_input.schema(),
331 right_input.schema(),
332 &JoinType::Inner,
333 )?);
334
335 return Ok(LogicalPlan::Join(Join {
336 left: Arc::new(left_input),
337 right: Arc::new(right_input),
338 join_type: JoinType::Inner,
339 join_constraint: JoinConstraint::On,
340 on: join_keys,
341 filter: None,
342 schema: join_schema,
343 null_equality,
344 }));
345 }
346 }
347
348 let right = rights.remove(0);
351 let join_schema = Arc::new(build_join_schema(
352 left_input.schema(),
353 right.schema(),
354 &JoinType::Inner,
355 )?);
356
357 Ok(LogicalPlan::Join(Join {
358 left: Arc::new(left_input),
359 right: Arc::new(right),
360 schema: join_schema,
361 on: vec![],
362 filter: None,
363 join_type: JoinType::Inner,
364 join_constraint: JoinConstraint::On,
365 null_equality,
366 }))
367}
368
369fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
371 if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
372 match op {
373 Operator::Eq => {
374 join_keys.insert(left, right);
376 }
377 Operator::And => {
378 extract_possible_join_keys(left, join_keys);
379 extract_possible_join_keys(right, join_keys)
380 }
381 Operator::Or => {
383 let mut left_join_keys = JoinKeySet::new();
384 let mut right_join_keys = JoinKeySet::new();
385
386 extract_possible_join_keys(left, &mut left_join_keys);
387 extract_possible_join_keys(right, &mut right_join_keys);
388
389 join_keys.insert_intersection(&left_join_keys, &right_join_keys)
390 }
391 _ => (),
392 };
393 }
394}
395
396fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
402 match expr {
403 Expr::BinaryExpr(BinaryExpr {
404 left,
405 op: Operator::Eq,
406 right,
407 }) if join_keys.contains(&left, &right) => {
408 None
410 }
411 Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
413 let l = remove_join_expressions(*left, join_keys);
414 let r = remove_join_expressions(*right, join_keys);
415 match (l, r) {
416 (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
417 Box::new(ll),
418 op,
419 Box::new(rr),
420 ))),
421 (Some(ll), _) => Some(ll),
422 (_, Some(rr)) => Some(rr),
423 _ => None,
424 }
425 }
426 Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
427 let l = remove_join_expressions(*left, join_keys);
428 let r = remove_join_expressions(*right, join_keys);
429 match (l, r) {
430 (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
431 Box::new(ll),
432 op,
433 Box::new(rr),
434 ))),
435 _ => None,
438 }
439 }
440 _ => Some(expr),
441 }
442}
443
444#[cfg(test)]
445mod tests {
446 use super::*;
447 use crate::optimizer::OptimizerContext;
448 use crate::test::*;
449
450 use datafusion_expr::{
451 Operator::{And, Or},
452 binary_expr, col, lit,
453 logical_plan::builder::LogicalPlanBuilder,
454 };
455 use insta::assert_snapshot;
456
457 macro_rules! assert_optimized_plan_equal {
458 (
459 $plan:expr,
460 @ $expected:literal $(,)?
461 ) => {{
462 let starting_schema = Arc::clone($plan.schema());
463 let rule = EliminateCrossJoin::new();
464 let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
465 let formatted_plan = optimized_plan.display_indent_schema();
466 assert!(is_plan_transformed, "failed to optimize plan");
468 assert_eq!(&starting_schema, optimized_plan.schema());
470 assert_snapshot!(
471 formatted_plan,
472 @ $expected,
473 );
474
475 Ok(())
476 }};
477 }
478
479 #[test]
480 fn eliminate_cross_with_simple_and() -> Result<()> {
481 let t1 = test_table_scan_with_name("t1")?;
482 let t2 = test_table_scan_with_name("t2")?;
483
484 let plan = LogicalPlanBuilder::from(t1)
486 .cross_join(t2)?
487 .filter(binary_expr(
488 col("t1.a").eq(col("t2.a")),
489 And,
490 col("t2.c").lt(lit(20u32)),
491 ))?
492 .build()?;
493
494 assert_optimized_plan_equal!(
495 plan,
496 @ r"
497 Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
498 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
499 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
500 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
501 "
502 )
503 }
504
505 #[test]
506 fn eliminate_cross_with_simple_or() -> Result<()> {
507 let t1 = test_table_scan_with_name("t1")?;
508 let t2 = test_table_scan_with_name("t2")?;
509
510 let plan = LogicalPlanBuilder::from(t1)
513 .cross_join(t2)?
514 .filter(binary_expr(
515 col("t1.a").eq(col("t2.a")),
516 Or,
517 col("t2.b").eq(col("t1.a")),
518 ))?
519 .build()?;
520
521 assert_optimized_plan_equal!(
522 plan,
523 @ r"
524 Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
525 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
526 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
527 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
528 "
529 )
530 }
531
532 #[test]
533 fn eliminate_cross_with_and() -> Result<()> {
534 let t1 = test_table_scan_with_name("t1")?;
535 let t2 = test_table_scan_with_name("t2")?;
536
537 let plan = LogicalPlanBuilder::from(t1)
539 .cross_join(t2)?
540 .filter(binary_expr(
541 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
542 And,
543 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
544 ))?
545 .build()?;
546
547 assert_optimized_plan_equal!(
548 plan,
549 @ r"
550 Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
551 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
552 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
553 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
554 "
555 )
556 }
557
558 #[test]
559 fn eliminate_cross_with_or() -> Result<()> {
560 let t1 = test_table_scan_with_name("t1")?;
561 let t2 = test_table_scan_with_name("t2")?;
562
563 let plan = LogicalPlanBuilder::from(t1)
565 .cross_join(t2)?
566 .filter(binary_expr(
567 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
568 Or,
569 binary_expr(
570 col("t1.a").eq(col("t2.a")),
571 And,
572 col("t2.c").eq(lit(688u32)),
573 ),
574 ))?
575 .build()?;
576
577 assert_optimized_plan_equal!(
578 plan,
579 @ r"
580 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
581 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
582 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
583 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
584 "
585 )
586 }
587
588 #[test]
589 fn eliminate_cross_not_possible_simple() -> Result<()> {
590 let t1 = test_table_scan_with_name("t1")?;
591 let t2 = test_table_scan_with_name("t2")?;
592
593 let plan = LogicalPlanBuilder::from(t1)
595 .cross_join(t2)?
596 .filter(binary_expr(
597 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
598 Or,
599 binary_expr(
600 col("t1.b").eq(col("t2.b")),
601 And,
602 col("t2.c").eq(lit(688u32)),
603 ),
604 ))?
605 .build()?;
606
607 assert_optimized_plan_equal!(
608 plan,
609 @ r"
610 Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
611 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
612 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
613 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
614 "
615 )
616 }
617
618 #[test]
619 fn eliminate_cross_not_possible() -> Result<()> {
620 let t1 = test_table_scan_with_name("t1")?;
621 let t2 = test_table_scan_with_name("t2")?;
622
623 let plan = LogicalPlanBuilder::from(t1)
625 .cross_join(t2)?
626 .filter(binary_expr(
627 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
628 Or,
629 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
630 ))?
631 .build()?;
632
633 assert_optimized_plan_equal!(
634 plan,
635 @ r"
636 Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
637 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
638 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
639 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
640 "
641 )
642 }
643
644 #[test]
645 fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
646 let t1 = test_table_scan_with_name("t1")?;
647 let t2 = test_table_scan_with_name("t2")?;
648 let t3 = test_table_scan_with_name("t3")?;
649
650 let plan = LogicalPlanBuilder::from(t1)
652 .join(
653 t3,
654 JoinType::Inner,
655 (vec!["t1.a"], vec!["t3.a"]),
656 Some(col("t1.a").gt(lit(20u32))),
657 )?
658 .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
659 .filter(col("t1.a").gt(lit(15u32)))?
660 .build()?;
661
662 assert_optimized_plan_equal!(
663 plan,
664 @ r"
665 Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
666 Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
667 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
668 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
669 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
670 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
671 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
672 "
673 )
674 }
675
676 #[test]
677 fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
693 let t1 = test_table_scan_with_name("t1")?;
694 let t2 = test_table_scan_with_name("t2")?;
695 let t3 = test_table_scan_with_name("t3")?;
696
697 let plan = LogicalPlanBuilder::from(t1)
699 .cross_join(t2)?
700 .cross_join(t3)?
701 .filter(binary_expr(
702 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
703 And,
704 binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
705 ))?
706 .build()?;
707
708 assert_optimized_plan_equal!(
709 plan,
710 @ r"
711 Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
712 Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
713 Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
714 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
715 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
716 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
717 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
718 "
719 )
720 }
721
722 #[test]
723 fn eliminate_cross_join_multi_tables() -> Result<()> {
724 let t1 = test_table_scan_with_name("t1")?;
725 let t2 = test_table_scan_with_name("t2")?;
726 let t3 = test_table_scan_with_name("t3")?;
727 let t4 = test_table_scan_with_name("t4")?;
728
729 let plan1 = LogicalPlanBuilder::from(t1)
731 .cross_join(t2)?
732 .filter(binary_expr(
733 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
734 Or,
735 binary_expr(
736 col("t1.a").eq(col("t2.a")),
737 And,
738 col("t2.c").eq(lit(688u32)),
739 ),
740 ))?
741 .build()?;
742
743 let plan2 = LogicalPlanBuilder::from(t3)
744 .cross_join(t4)?
745 .filter(binary_expr(
746 binary_expr(
747 binary_expr(
748 col("t3.a").eq(col("t4.a")),
749 And,
750 col("t4.c").lt(lit(15u32)),
751 ),
752 Or,
753 binary_expr(
754 col("t3.a").eq(col("t4.a")),
755 And,
756 col("t3.c").eq(lit(688u32)),
757 ),
758 ),
759 Or,
760 binary_expr(
761 col("t3.a").eq(col("t4.a")),
762 And,
763 col("t3.b").eq(col("t4.b")),
764 ),
765 ))?
766 .build()?;
767
768 let plan = LogicalPlanBuilder::from(plan1)
769 .cross_join(plan2)?
770 .filter(binary_expr(
771 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
772 Or,
773 binary_expr(
774 col("t3.a").eq(col("t1.a")),
775 And,
776 col("t4.c").eq(lit(688u32)),
777 ),
778 ))?
779 .build()?;
780
781 assert_optimized_plan_equal!(
782 plan,
783 @ r"
784 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
785 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
786 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
787 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
788 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
789 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
790 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
791 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
792 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
793 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
794 "
795 )
796 }
797
798 #[test]
799 fn eliminate_cross_join_multi_tables_1() -> Result<()> {
800 let t1 = test_table_scan_with_name("t1")?;
801 let t2 = test_table_scan_with_name("t2")?;
802 let t3 = test_table_scan_with_name("t3")?;
803 let t4 = test_table_scan_with_name("t4")?;
804
805 let plan1 = LogicalPlanBuilder::from(t1)
807 .cross_join(t2)?
808 .filter(binary_expr(
809 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
810 Or,
811 binary_expr(
812 col("t1.a").eq(col("t2.a")),
813 And,
814 col("t2.c").eq(lit(688u32)),
815 ),
816 ))?
817 .build()?;
818
819 let plan2 = LogicalPlanBuilder::from(t3)
821 .cross_join(t4)?
822 .filter(binary_expr(
823 binary_expr(
824 binary_expr(
825 col("t3.a").eq(col("t4.a")),
826 And,
827 col("t4.c").lt(lit(15u32)),
828 ),
829 Or,
830 binary_expr(
831 col("t3.a").eq(col("t4.a")),
832 And,
833 col("t3.c").eq(lit(688u32)),
834 ),
835 ),
836 Or,
837 binary_expr(
838 col("t3.a").eq(col("t4.a")),
839 And,
840 col("t3.b").eq(col("t4.b")),
841 ),
842 ))?
843 .build()?;
844
845 let plan = LogicalPlanBuilder::from(plan1)
847 .cross_join(plan2)?
848 .filter(binary_expr(
849 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
850 Or,
851 binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
852 ))?
853 .build()?;
854
855 assert_optimized_plan_equal!(
856 plan,
857 @ r"
858 Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
859 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
860 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
861 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
862 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
863 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
864 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
865 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
866 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
867 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
868 "
869 )
870 }
871
872 #[test]
873 fn eliminate_cross_join_multi_tables_2() -> Result<()> {
874 let t1 = test_table_scan_with_name("t1")?;
875 let t2 = test_table_scan_with_name("t2")?;
876 let t3 = test_table_scan_with_name("t3")?;
877 let t4 = test_table_scan_with_name("t4")?;
878
879 let plan1 = LogicalPlanBuilder::from(t1)
881 .cross_join(t2)?
882 .filter(binary_expr(
883 binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
884 Or,
885 binary_expr(
886 col("t1.a").eq(col("t2.a")),
887 And,
888 col("t2.c").eq(lit(688u32)),
889 ),
890 ))?
891 .build()?;
892
893 let plan2 = LogicalPlanBuilder::from(t3)
895 .cross_join(t4)?
896 .filter(binary_expr(
897 binary_expr(
898 binary_expr(
899 col("t3.a").eq(col("t4.a")),
900 And,
901 col("t4.c").lt(lit(15u32)),
902 ),
903 Or,
904 binary_expr(
905 col("t3.a").eq(col("t4.a")),
906 And,
907 col("t3.c").eq(lit(688u32)),
908 ),
909 ),
910 Or,
911 binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
912 ))?
913 .build()?;
914
915 let plan = LogicalPlanBuilder::from(plan1)
917 .cross_join(plan2)?
918 .filter(binary_expr(
919 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
920 Or,
921 binary_expr(
922 col("t3.a").eq(col("t1.a")),
923 And,
924 col("t4.c").eq(lit(688u32)),
925 ),
926 ))?
927 .build()?;
928
929 assert_optimized_plan_equal!(
930 plan,
931 @ r"
932 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
933 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
934 Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
935 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
936 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
937 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
938 Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
939 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
940 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
941 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
942 "
943 )
944 }
945
946 #[test]
947 fn eliminate_cross_join_multi_tables_3() -> Result<()> {
948 let t1 = test_table_scan_with_name("t1")?;
949 let t2 = test_table_scan_with_name("t2")?;
950 let t3 = test_table_scan_with_name("t3")?;
951 let t4 = test_table_scan_with_name("t4")?;
952
953 let plan1 = LogicalPlanBuilder::from(t1)
955 .cross_join(t2)?
956 .filter(binary_expr(
957 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
958 Or,
959 binary_expr(
960 col("t1.a").eq(col("t2.a")),
961 And,
962 col("t2.c").eq(lit(688u32)),
963 ),
964 ))?
965 .build()?;
966
967 let plan2 = LogicalPlanBuilder::from(t3)
969 .cross_join(t4)?
970 .filter(binary_expr(
971 binary_expr(
972 binary_expr(
973 col("t3.a").eq(col("t4.a")),
974 And,
975 col("t4.c").lt(lit(15u32)),
976 ),
977 Or,
978 binary_expr(
979 col("t3.a").eq(col("t4.a")),
980 And,
981 col("t3.c").eq(lit(688u32)),
982 ),
983 ),
984 Or,
985 binary_expr(
986 col("t3.a").eq(col("t4.a")),
987 And,
988 col("t3.b").eq(col("t4.b")),
989 ),
990 ))?
991 .build()?;
992
993 let plan = LogicalPlanBuilder::from(plan1)
995 .cross_join(plan2)?
996 .filter(binary_expr(
997 binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
998 Or,
999 binary_expr(
1000 col("t3.a").eq(col("t1.a")),
1001 And,
1002 col("t4.c").eq(lit(688u32)),
1003 ),
1004 ))?
1005 .build()?;
1006
1007 assert_optimized_plan_equal!(
1008 plan,
1009 @ r"
1010 Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1011 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1012 Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1013 Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1014 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1015 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1016 Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1017 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1018 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1019 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1020 "
1021 )
1022 }
1023
1024 #[test]
1025 fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1026 let t1 = test_table_scan_with_name("t1")?;
1027 let t2 = test_table_scan_with_name("t2")?;
1028 let t3 = test_table_scan_with_name("t3")?;
1029 let t4 = test_table_scan_with_name("t4")?;
1030
1031 let plan1 = LogicalPlanBuilder::from(t1)
1034 .cross_join(t2)?
1035 .filter(binary_expr(
1036 binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1037 And,
1038 binary_expr(
1039 col("t1.a").eq(col("t2.a")),
1040 And,
1041 col("t2.c").eq(lit(688u32)),
1042 ),
1043 ))?
1044 .build()?;
1045
1046 let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1048
1049 let plan = LogicalPlanBuilder::from(plan1)
1055 .cross_join(plan2)?
1056 .filter(binary_expr(
1057 binary_expr(
1058 binary_expr(
1059 col("t3.a").eq(col("t1.a")),
1060 And,
1061 col("t4.c").lt(lit(15u32)),
1062 ),
1063 Or,
1064 binary_expr(
1065 col("t3.a").eq(col("t1.a")),
1066 And,
1067 col("t4.c").eq(lit(688u32)),
1068 ),
1069 ),
1070 And,
1071 binary_expr(
1072 binary_expr(
1073 binary_expr(
1074 col("t3.a").eq(col("t4.a")),
1075 And,
1076 col("t4.c").lt(lit(15u32)),
1077 ),
1078 Or,
1079 binary_expr(
1080 col("t3.a").eq(col("t4.a")),
1081 And,
1082 col("t3.c").eq(lit(688u32)),
1083 ),
1084 ),
1085 Or,
1086 binary_expr(
1087 col("t3.a").eq(col("t4.a")),
1088 And,
1089 col("t3.b").eq(col("t4.b")),
1090 ),
1091 ),
1092 ))?
1093 .build()?;
1094
1095 assert_optimized_plan_equal!(
1096 plan,
1097 @ r"
1098 Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1099 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1100 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1101 Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1102 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1103 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1104 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1105 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1106 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1107 "
1108 )
1109 }
1110
1111 #[test]
1112 fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1113 let t1 = test_table_scan_with_name("t1")?;
1114 let t2 = test_table_scan_with_name("t2")?;
1115 let t3 = test_table_scan_with_name("t3")?;
1116 let t4 = test_table_scan_with_name("t4")?;
1117
1118 let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1120
1121 let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1123
1124 let plan = LogicalPlanBuilder::from(plan1)
1132 .cross_join(plan2)?
1133 .filter(binary_expr(
1134 binary_expr(
1135 binary_expr(
1136 binary_expr(
1137 col("t3.a").eq(col("t1.a")),
1138 And,
1139 col("t4.c").lt(lit(15u32)),
1140 ),
1141 Or,
1142 binary_expr(
1143 col("t3.a").eq(col("t1.a")),
1144 And,
1145 col("t4.c").eq(lit(688u32)),
1146 ),
1147 ),
1148 And,
1149 binary_expr(
1150 binary_expr(
1151 binary_expr(
1152 col("t3.a").eq(col("t4.a")),
1153 And,
1154 col("t4.c").lt(lit(15u32)),
1155 ),
1156 Or,
1157 binary_expr(
1158 col("t3.a").eq(col("t4.a")),
1159 And,
1160 col("t3.c").eq(lit(688u32)),
1161 ),
1162 ),
1163 Or,
1164 binary_expr(
1165 col("t3.a").eq(col("t4.a")),
1166 And,
1167 col("t3.b").eq(col("t4.b")),
1168 ),
1169 ),
1170 ),
1171 And,
1172 binary_expr(
1173 binary_expr(
1174 col("t1.a").eq(col("t2.a")),
1175 Or,
1176 col("t2.c").lt(lit(15u32)),
1177 ),
1178 And,
1179 binary_expr(
1180 col("t1.a").eq(col("t2.a")),
1181 And,
1182 col("t2.c").eq(lit(688u32)),
1183 ),
1184 ),
1185 ))?
1186 .build()?;
1187
1188 assert_optimized_plan_equal!(
1189 plan,
1190 @ r"
1191 Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1192 Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1193 Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1194 Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1195 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1196 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1197 TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1198 TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1199 "
1200 )
1201 }
1202
1203 #[test]
1204 fn eliminate_cross_join_with_expr_and() -> Result<()> {
1205 let t1 = test_table_scan_with_name("t1")?;
1206 let t2 = test_table_scan_with_name("t2")?;
1207
1208 let plan = LogicalPlanBuilder::from(t1)
1210 .cross_join(t2)?
1211 .filter(binary_expr(
1212 (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1213 And,
1214 col("t2.c").lt(lit(20u32)),
1215 ))?
1216 .build()?;
1217
1218 assert_optimized_plan_equal!(
1219 plan,
1220 @ r"
1221 Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1222 Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1223 TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1224 TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1225 "
1226 )
1227 }
1228
#[test]
fn eliminate_cross_with_expr_or() -> Result<()> {
    let left_scan = test_table_scan_with_name("t1")?;
    let right_scan = test_table_scan_with_name("t2")?;

    // The two OR branches share no join predicate, so no equijoin key can be
    // extracted: the cross join is expected to survive with the filter intact.
    let expr_key_branch = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let column_branch = col("t2.b").eq(col("t1.a"));

    let plan = LogicalPlanBuilder::from(left_scan)
        .cross_join(right_scan)?
        .filter(expr_key_branch.or(column_branch))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1255
#[test]
fn eliminate_cross_with_common_expr_and() -> Result<()> {
    let left_scan = test_table_scan_with_name("t1")?;
    let right_scan = test_table_scan_with_name("t2")?;

    // The same expression key appears in both AND-ed halves; it should be
    // used once as the join key, leaving only the t2.c conditions behind.
    let shared_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let first_half = shared_key.clone().and(col("t2.c").lt(lit(20u32)));
    let second_half = shared_key.and(col("t2.c").eq(lit(10u32)));

    let plan = LogicalPlanBuilder::from(left_scan)
        .cross_join(right_scan)?
        .filter(first_half.and(second_half))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1282
#[test]
fn eliminate_cross_with_common_expr_or() -> Result<()> {
    let left_scan = test_table_scan_with_name("t1")?;
    let right_scan = test_table_scan_with_name("t2")?;

    // Both OR branches contain the same expression key, so it is common to
    // the whole disjunction and can become the join key; the remaining
    // per-branch conditions stay OR-ed together in the filter.
    let shared_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
    let low_branch = shared_key.clone().and(col("t2.c").lt(lit(15u32)));
    let exact_branch = shared_key.and(col("t2.c").eq(lit(688u32)));

    let plan = LogicalPlanBuilder::from(left_scan)
        .cross_join(right_scan)?
        .filter(low_branch.or(exact_branch))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1309
#[test]
fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
    let scan_t1 = test_table_scan_with_name("t1")?;
    let scan_t2 = test_table_scan_with_name("t2")?;
    let scan_t3 = test_table_scan_with_name("t3")?;

    // t3 is related to t1 and to t2 only through expression keys, and t1/t2
    // share no key. The rule therefore joins t1 with t3 first, then t2, and a
    // Projection restores the original t1, t2, t3 column order.
    let t1_t3_conjunct = (col("t3.a") + lit(100u32))
        .eq(col("t1.a") * lit(2u32))
        .and(col("t3.c").lt(lit(15u32)));
    let t2_t3_conjunct = (col("t3.a") + lit(100u32))
        .eq(col("t2.a") * lit(2u32))
        .and(col("t3.b").lt(lit(15u32)));

    let plan = LogicalPlanBuilder::from(scan_t1)
        .cross_join(scan_t2)?
        .cross_join(scan_t3)?
        .filter(t1_t3_conjunct.and(t2_t3_conjunct))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
      Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
        Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
          Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
    "
    )
}
1348
#[test]
fn preserve_null_equality_setting() -> Result<()> {
    let t1 = test_table_scan_with_name("t1")?;
    let t2 = test_table_scan_with_name("t2")?;

    let join_schema = Arc::new(build_join_schema(
        t1.schema(),
        t2.schema(),
        &JoinType::Inner,
    )?);

    // An inner join with no keys and no filter (i.e. a cross join) that
    // explicitly treats NULL join keys as equal. The rewrite must carry this
    // setting over to the join(s) it builds.
    let inner_join = LogicalPlan::Join(Join {
        left: Arc::new(t1),
        right: Arc::new(t2),
        join_type: JoinType::Inner,
        join_constraint: JoinConstraint::On,
        on: vec![],
        filter: None,
        schema: join_schema,
        null_equality: NullEquality::NullEqualsNull,
    });

    // `t1.a = t2.a` is the predicate the rule can pull into the join as an
    // equijoin key; `t2.c < 20` stays in the filter.
    let plan = LogicalPlanBuilder::from(inner_join)
        .filter(binary_expr(
            col("t1.a").eq(col("t2.a")),
            And,
            col("t2.c").lt(lit(20u32)),
        ))?
        .build()?;

    let rule = EliminateCrossJoin::new();
    let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;

    /// Appends every `Join` node of `plan` (pre-order) to `joins`.
    fn collect_joins<'a>(plan: &'a LogicalPlan, joins: &mut Vec<&'a Join>) {
        if let LogicalPlan::Join(join) = plan {
            joins.push(join);
        }
        for input in plan.inputs() {
            collect_joins(input, joins);
        }
    }

    let mut joins = Vec::new();
    collect_joins(&optimized_plan, &mut joins);

    // Without these two checks the per-join assertion below would pass
    // vacuously if the rewrite dropped the join or never extracted a key.
    assert!(
        !joins.is_empty(),
        "optimized plan should still contain a join:\n{}",
        optimized_plan.display_indent()
    );
    assert!(
        joins.iter().any(|join| !join.on.is_empty()),
        "the filter's equijoin key should have been pushed into the join:\n{}",
        optimized_plan.display_indent()
    );

    for join in &joins {
        assert!(
            join.null_equality == NullEquality::NullEqualsNull,
            "null_equality setting should be preserved after optimization:\n{}",
            optimized_plan.display_indent()
        );
    }

    Ok(())
}
1413}