1use std::cmp::{min, Ordering};
21use std::collections::HashSet;
22use std::fmt::{self, Debug};
23use std::future::Future;
24use std::iter::once;
25use std::ops::Range;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29use crate::joins::SharedBitmapBuilder;
30use crate::metrics::{self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
31use crate::projection::{ProjectionExec, ProjectionExpr};
32use crate::{
33 ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
34};
35pub use super::join_filter::JoinFilter;
37pub use super::join_hash_map::JoinHashMapType;
38pub use crate::joins::{JoinOn, JoinOnRef};
39
40use ahash::RandomState;
41use arrow::array::{
42 builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType,
43 BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions,
44 UInt32Array, UInt32Builder, UInt64Array,
45};
46use arrow::array::{
47 ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
48 Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array,
49 Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray,
50 StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray,
51 TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array,
52};
53use arrow::buffer::{BooleanBuffer, NullBuffer};
54use arrow::compute::kernels::cmp::eq;
55use arrow::compute::{self, and, take, FilterBuilder};
56use arrow::datatypes::{
57 ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
58};
59use arrow_ord::cmp::not_distinct;
60use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit};
61use datafusion_common::cast::as_boolean_array;
62use datafusion_common::hash_utils::create_hashes;
63use datafusion_common::stats::Precision;
64use datafusion_common::{
65 not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result,
66 SharedResult,
67};
68use datafusion_expr::interval_arithmetic::Interval;
69use datafusion_expr::Operator;
70use datafusion_physical_expr::expressions::Column;
71use datafusion_physical_expr::utils::collect_columns;
72use datafusion_physical_expr::{
73 add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr,
74 PhysicalExprRef,
75};
76
77use datafusion_physical_expr_common::datum::compare_op_for_nested;
78use futures::future::{BoxFuture, Shared};
79use futures::{ready, FutureExt};
80use parking_lot::Mutex;
81
82pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
85 let left: HashSet<Column> = left
86 .fields()
87 .iter()
88 .enumerate()
89 .map(|(idx, f)| Column::new(f.name(), idx))
90 .collect();
91 let right: HashSet<Column> = right
92 .fields()
93 .iter()
94 .enumerate()
95 .map(|(idx, f)| Column::new(f.name(), idx))
96 .collect();
97
98 check_join_set_is_valid(&left, &right, on)
99}
100
101fn check_join_set_is_valid(
104 left: &HashSet<Column>,
105 right: &HashSet<Column>,
106 on: &[(PhysicalExprRef, PhysicalExprRef)],
107) -> Result<()> {
108 let on_left = &on
109 .iter()
110 .flat_map(|on| collect_columns(&on.0))
111 .collect::<HashSet<_>>();
112 let left_missing = on_left.difference(left).collect::<HashSet<_>>();
113
114 let on_right = &on
115 .iter()
116 .flat_map(|on| collect_columns(&on.1))
117 .collect::<HashSet<_>>();
118 let right_missing = on_right.difference(right).collect::<HashSet<_>>();
119
120 if !left_missing.is_empty() | !right_missing.is_empty() {
121 return plan_err!(
122 "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}"
123 );
124 };
125
126 Ok(())
127}
128
129pub fn adjust_right_output_partitioning(
131 right_partitioning: &Partitioning,
132 left_columns_len: usize,
133) -> Result<Partitioning> {
134 let result = match right_partitioning {
135 Partitioning::Hash(exprs, size) => {
136 let new_exprs = exprs
137 .iter()
138 .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _))
139 .collect::<Result<_>>()?;
140 Partitioning::Hash(new_exprs, *size)
141 }
142 result => result.clone(),
143 };
144 Ok(result)
145}
146
147pub fn calculate_join_output_ordering(
149 left_ordering: Option<&LexOrdering>,
150 right_ordering: Option<&LexOrdering>,
151 join_type: JoinType,
152 left_columns_len: usize,
153 maintains_input_order: &[bool],
154 probe_side: Option<JoinSide>,
155) -> Result<Option<LexOrdering>> {
156 match maintains_input_order {
157 [true, false] => {
158 if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) {
160 if let Some(right_ordering) = right_ordering.cloned() {
161 let right_offset = add_offset_to_physical_sort_exprs(
162 right_ordering,
163 left_columns_len as _,
164 )?;
165 return if let Some(left_ordering) = left_ordering {
166 let mut result = left_ordering.clone();
167 result.extend(right_offset);
168 Ok(Some(result))
169 } else {
170 Ok(LexOrdering::new(right_offset))
171 };
172 }
173 }
174 Ok(left_ordering.cloned())
175 }
176 [false, true] => {
177 if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) {
179 return if let Some(right_ordering) = right_ordering.cloned() {
180 let mut right_offset = add_offset_to_physical_sort_exprs(
181 right_ordering,
182 left_columns_len as _,
183 )?;
184 if let Some(left_ordering) = left_ordering {
185 right_offset.extend(left_ordering.clone());
186 }
187 Ok(LexOrdering::new(right_offset))
188 } else {
189 Ok(left_ordering.cloned())
190 };
191 }
192 let Some(right_ordering) = right_ordering else {
193 return Ok(None);
194 };
195 match join_type {
196 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
197 add_offset_to_physical_sort_exprs(
198 right_ordering.clone(),
199 left_columns_len as _,
200 )
201 .map(LexOrdering::new)
202 }
203 _ => Ok(Some(right_ordering.clone())),
204 }
205 }
206 [false, false] => Ok(None),
208 [true, true] => unreachable!("Cannot maintain ordering of both sides"),
209 _ => unreachable!("Join operators can not have more than two children"),
210 }
211}
212
213#[derive(Debug, Clone, PartialEq)]
215pub struct ColumnIndex {
216 pub index: usize,
218 pub side: JoinSide,
220}
221
222fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
225 let force_nullable = match join_type {
226 JoinType::Inner => false,
227 JoinType::Left => !is_left, JoinType::Right => is_left, JoinType::Full => true, JoinType::LeftSemi => false, JoinType::RightSemi => false, JoinType::LeftAnti => false, JoinType::RightAnti => false, JoinType::LeftMark => false,
235 JoinType::RightMark => false,
236 };
237
238 if force_nullable {
239 old_field.clone().with_nullable(true)
240 } else {
241 old_field.clone()
242 }
243}
244
245pub fn build_join_schema(
248 left: &Schema,
249 right: &Schema,
250 join_type: &JoinType,
251) -> (Schema, Vec<ColumnIndex>) {
252 let left_fields = || {
253 left.fields()
254 .iter()
255 .map(|f| output_join_field(f, join_type, true))
256 .enumerate()
257 .map(|(index, f)| {
258 (
259 f,
260 ColumnIndex {
261 index,
262 side: JoinSide::Left,
263 },
264 )
265 })
266 };
267
268 let right_fields = || {
269 right
270 .fields()
271 .iter()
272 .map(|f| output_join_field(f, join_type, false))
273 .enumerate()
274 .map(|(index, f)| {
275 (
276 f,
277 ColumnIndex {
278 index,
279 side: JoinSide::Right,
280 },
281 )
282 })
283 };
284
285 let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
286 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
287 left_fields().chain(right_fields()).unzip()
289 }
290 JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
291 JoinType::LeftMark => {
292 let right_field = once((
293 Field::new("mark", DataType::Boolean, false),
294 ColumnIndex {
295 index: 0,
296 side: JoinSide::None,
297 },
298 ));
299 left_fields().chain(right_field).unzip()
300 }
301 JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
302 JoinType::RightMark => {
303 let left_field = once((
304 Field::new("mark", DataType::Boolean, false),
305 ColumnIndex {
306 index: 0,
307 side: JoinSide::None,
308 },
309 ));
310 right_fields().chain(left_field).unzip()
311 }
312 };
313
314 let (schema1, schema2) = match join_type {
315 JoinType::Right
316 | JoinType::RightSemi
317 | JoinType::RightAnti
318 | JoinType::RightMark => (left, right),
319 _ => (right, left),
320 };
321
322 let metadata = schema1
323 .metadata()
324 .clone()
325 .into_iter()
326 .chain(schema2.metadata().clone())
327 .collect();
328
329 (fields.finish().with_metadata(metadata), column_indices)
330}
331
332pub(crate) struct OnceAsync<T> {
346 fut: Mutex<Option<SharedResult<OnceFut<T>>>>,
347}
348
349impl<T> Default for OnceAsync<T> {
350 fn default() -> Self {
351 Self {
352 fut: Mutex::new(None),
353 }
354 }
355}
356
357impl<T> Debug for OnceAsync<T> {
358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359 write!(f, "OnceAsync")
360 }
361}
362
363impl<T: 'static> OnceAsync<T> {
364 pub(crate) fn try_once<F, Fut>(&self, f: F) -> Result<OnceFut<T>>
372 where
373 F: FnOnce() -> Result<Fut>,
374 Fut: Future<Output = Result<T>> + Send + 'static,
375 {
376 self.fut
377 .lock()
378 .get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new))
379 .clone()
380 .map_err(DataFusionError::Shared)
381 }
382}
383
384type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;
386
387pub(crate) struct OnceFut<T> {
391 state: OnceFutState<T>,
392}
393
394impl<T> Clone for OnceFut<T> {
395 fn clone(&self) -> Self {
396 Self {
397 state: self.state.clone(),
398 }
399 }
400}
401
402#[derive(Clone, Debug, Default)]
405struct PartialJoinStatistics {
406 pub num_rows: usize,
407 pub column_statistics: Vec<ColumnStatistics>,
408}
409
410pub(crate) fn estimate_join_statistics(
412 left_stats: Statistics,
413 right_stats: Statistics,
414 on: JoinOn,
415 join_type: &JoinType,
416 schema: &Schema,
417) -> Result<Statistics> {
418 let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on);
419 let (num_rows, column_statistics) = match join_stats {
420 Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics),
421 None => (Precision::Absent, Statistics::unknown_column(schema)),
422 };
423 Ok(Statistics {
424 num_rows,
425 total_byte_size: Precision::Absent,
426 column_statistics,
427 })
428}
429
430fn estimate_join_cardinality(
432 join_type: &JoinType,
433 left_stats: Statistics,
434 right_stats: Statistics,
435 on: &JoinOn,
436) -> Option<PartialJoinStatistics> {
437 let (left_col_stats, right_col_stats) = on
438 .iter()
439 .map(|(left, right)| {
440 match (
441 left.as_any().downcast_ref::<Column>(),
442 right.as_any().downcast_ref::<Column>(),
443 ) {
444 (Some(left), Some(right)) => (
445 left_stats.column_statistics[left.index()].clone(),
446 right_stats.column_statistics[right.index()].clone(),
447 ),
448 _ => (
449 ColumnStatistics::new_unknown(),
450 ColumnStatistics::new_unknown(),
451 ),
452 }
453 })
454 .unzip::<_, _, Vec<_>, Vec<_>>();
455
456 match join_type {
457 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
458 let ij_cardinality = estimate_inner_join_cardinality(
459 Statistics {
460 num_rows: left_stats.num_rows,
461 total_byte_size: Precision::Absent,
462 column_statistics: left_col_stats,
463 },
464 Statistics {
465 num_rows: right_stats.num_rows,
466 total_byte_size: Precision::Absent,
467 column_statistics: right_col_stats,
468 },
469 )?;
470
471 let cardinality = match join_type {
476 JoinType::Inner => ij_cardinality,
477 JoinType::Left => ij_cardinality.max(&left_stats.num_rows),
478 JoinType::Right => ij_cardinality.max(&right_stats.num_rows),
479 JoinType::Full => ij_cardinality
480 .max(&left_stats.num_rows)
481 .add(&ij_cardinality.max(&right_stats.num_rows))
482 .sub(&ij_cardinality),
483 _ => unreachable!(),
484 };
485
486 Some(PartialJoinStatistics {
487 num_rows: *cardinality.get_value()?,
488 column_statistics: left_stats
493 .column_statistics
494 .into_iter()
495 .chain(right_stats.column_statistics)
496 .collect(),
497 })
498 }
499
500 JoinType::LeftSemi | JoinType::RightSemi => {
504 let (outer_stats, inner_stats) = match join_type {
505 JoinType::LeftSemi => (left_stats, right_stats),
506 _ => (right_stats, left_stats),
507 };
508 let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
509 Some(estimation) => *estimation.get_value()?,
510 None => *outer_stats.num_rows.get_value()?,
511 };
512
513 Some(PartialJoinStatistics {
514 num_rows: cardinality,
515 column_statistics: outer_stats.column_statistics,
516 })
517 }
518
519 JoinType::LeftAnti | JoinType::RightAnti => {
522 let outer_stats = match join_type {
523 JoinType::LeftAnti => left_stats,
524 _ => right_stats,
525 };
526
527 Some(PartialJoinStatistics {
528 num_rows: *outer_stats.num_rows.get_value()?,
529 column_statistics: outer_stats.column_statistics,
530 })
531 }
532
533 JoinType::LeftMark => {
534 let num_rows = *left_stats.num_rows.get_value()?;
535 let mut column_statistics = left_stats.column_statistics;
536 column_statistics.push(ColumnStatistics::new_unknown());
537 Some(PartialJoinStatistics {
538 num_rows,
539 column_statistics,
540 })
541 }
542 JoinType::RightMark => {
543 let num_rows = *right_stats.num_rows.get_value()?;
544 let mut column_statistics = right_stats.column_statistics;
545 column_statistics.push(ColumnStatistics::new_unknown());
546 Some(PartialJoinStatistics {
547 num_rows,
548 column_statistics,
549 })
550 }
551 }
552}
553
554fn estimate_inner_join_cardinality(
559 left_stats: Statistics,
560 right_stats: Statistics,
561) -> Option<Precision<usize>> {
562 if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) {
564 return Some(estimation);
565 };
566
567 let mut join_selectivity = Precision::Absent;
570 for (left_stat, right_stat) in left_stats
571 .column_statistics
572 .iter()
573 .zip(right_stats.column_statistics.iter())
574 {
575 let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat);
576 let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat);
577 let max_distinct = left_max_distinct.max(&right_max_distinct);
578 if max_distinct.get_value().is_some() {
579 join_selectivity = max_distinct;
583 }
584 }
585
586 let left_num_rows = left_stats.num_rows.get_value()?;
590 let right_num_rows = right_stats.num_rows.get_value()?;
591 match join_selectivity {
592 Precision::Exact(value) if value > 0 => {
593 Some(Precision::Exact((left_num_rows * right_num_rows) / value))
594 }
595 Precision::Inexact(value) if value > 0 => {
596 Some(Precision::Inexact((left_num_rows * right_num_rows) / value))
597 }
598 _ => None,
603 }
604}
605
606fn estimate_disjoint_inputs(
609 left_stats: &Statistics,
610 right_stats: &Statistics,
611) -> Option<Precision<usize>> {
612 for (left_stat, right_stat) in left_stats
613 .column_statistics
614 .iter()
615 .zip(right_stats.column_statistics.iter())
616 {
617 let left_min_val = left_stat.min_value.get_value();
621 let right_max_val = right_stat.max_value.get_value();
622 if left_min_val.is_some()
623 && right_max_val.is_some()
624 && left_min_val > right_max_val
625 {
626 return Some(
627 if left_stat.min_value.is_exact().unwrap_or(false)
628 && right_stat.max_value.is_exact().unwrap_or(false)
629 {
630 Precision::Exact(0)
631 } else {
632 Precision::Inexact(0)
633 },
634 );
635 }
636
637 let left_max_val = left_stat.max_value.get_value();
638 let right_min_val = right_stat.min_value.get_value();
639 if left_max_val.is_some()
640 && right_min_val.is_some()
641 && left_max_val < right_min_val
642 {
643 return Some(
644 if left_stat.max_value.is_exact().unwrap_or(false)
645 && right_stat.min_value.is_exact().unwrap_or(false)
646 {
647 Precision::Exact(0)
648 } else {
649 Precision::Inexact(0)
650 },
651 );
652 }
653 }
654
655 None
656}
657
658fn max_distinct_count(
664 num_rows: &Precision<usize>,
665 stats: &ColumnStatistics,
666) -> Precision<usize> {
667 match &stats.distinct_count {
668 &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc,
669 _ => {
670 let result = match num_rows {
673 Precision::Absent => Precision::Absent,
674 Precision::Inexact(count) => {
675 match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
678 None => Precision::Inexact(0),
679 Some(non_null_count) => Precision::Inexact(non_null_count),
680 }
681 }
682 Precision::Exact(count) => {
683 let count = count - stats.null_count.get_value().unwrap_or(&0);
684 if stats.null_count.is_exact().unwrap_or(false) {
685 Precision::Exact(count)
686 } else {
687 Precision::Inexact(count)
688 }
689 }
690 };
691 if let (Some(min), Some(max)) =
693 (stats.min_value.get_value(), stats.max_value.get_value())
694 {
695 if let Some(range_dc) = Interval::try_new(min.clone(), max.clone())
696 .ok()
697 .and_then(|e| e.cardinality())
698 {
699 let range_dc = range_dc as usize;
700 return if matches!(result, Precision::Absent)
702 || &range_dc < result.get_value().unwrap()
703 {
704 if stats.min_value.is_exact().unwrap()
705 && stats.max_value.is_exact().unwrap()
706 {
707 Precision::Exact(range_dc)
708 } else {
709 Precision::Inexact(range_dc)
710 }
711 } else {
712 result
713 };
714 }
715 }
716
717 result
718 }
719 }
720}
721
722enum OnceFutState<T> {
723 Pending(OnceFutPending<T>),
724 Ready(SharedResult<Arc<T>>),
725}
726
727impl<T> Clone for OnceFutState<T> {
728 fn clone(&self) -> Self {
729 match self {
730 Self::Pending(p) => Self::Pending(p.clone()),
731 Self::Ready(r) => Self::Ready(r.clone()),
732 }
733 }
734}
735
736impl<T: 'static> OnceFut<T> {
737 pub(crate) fn new<Fut>(fut: Fut) -> Self
739 where
740 Fut: Future<Output = Result<T>> + Send + 'static,
741 {
742 Self {
743 state: OnceFutState::Pending(
744 fut.map(|res| res.map(Arc::new).map_err(Arc::new))
745 .boxed()
746 .shared(),
747 ),
748 }
749 }
750
751 pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
753 if let OnceFutState::Pending(fut) = &mut self.state {
754 let r = ready!(fut.poll_unpin(cx));
755 self.state = OnceFutState::Ready(r);
756 }
757
758 match &self.state {
760 OnceFutState::Pending(_) => unreachable!(),
761 OnceFutState::Ready(r) => Poll::Ready(
762 r.as_ref()
763 .map(|r| r.as_ref())
764 .map_err(DataFusionError::from),
765 ),
766 }
767 }
768
769 pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
771 if let OnceFutState::Pending(fut) = &mut self.state {
772 let r = ready!(fut.poll_unpin(cx));
773 self.state = OnceFutState::Ready(r);
774 }
775
776 match &self.state {
777 OnceFutState::Pending(_) => unreachable!(),
778 OnceFutState::Ready(r) => {
779 Poll::Ready(r.clone().map_err(DataFusionError::Shared))
780 }
781 }
782 }
783}
784
785pub(crate) fn need_produce_right_in_final(join_type: JoinType) -> bool {
792 matches!(
793 join_type,
794 JoinType::Full
795 | JoinType::Right
796 | JoinType::RightAnti
797 | JoinType::RightMark
798 | JoinType::RightSemi
799 )
800}
801
802pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
808 matches!(
809 join_type,
810 JoinType::Left
811 | JoinType::LeftAnti
812 | JoinType::LeftSemi
813 | JoinType::LeftMark
814 | JoinType::Full
815 )
816}
817
818pub(crate) fn get_final_indices_from_shared_bitmap(
819 shared_bitmap: &SharedBitmapBuilder,
820 join_type: JoinType,
821 piecewise: bool,
822) -> (UInt64Array, UInt32Array) {
823 let bitmap = shared_bitmap.lock();
824 get_final_indices_from_bit_map(&bitmap, join_type, piecewise)
825}
826
827pub(crate) fn get_final_indices_from_bit_map(
837 left_bit_map: &BooleanBufferBuilder,
838 join_type: JoinType,
839 piecewise: bool,
842) -> (UInt64Array, UInt32Array) {
843 let left_size = left_bit_map.len();
844 if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise)
845 {
846 let left_indices = (0..left_size as u64).collect::<UInt64Array>();
847 let right_indices = (0..left_size)
848 .map(|idx| left_bit_map.get_bit(idx).then_some(0))
849 .collect::<UInt32Array>();
850 return (left_indices, right_indices);
851 }
852 let left_indices = if join_type == JoinType::LeftSemi
853 || (join_type == JoinType::RightSemi && piecewise)
854 {
855 (0..left_size)
856 .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
857 .collect::<UInt64Array>()
858 } else {
859 (0..left_size)
862 .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
863 .collect::<UInt64Array>()
864 };
865 let mut builder = UInt32Builder::with_capacity(left_indices.len());
868 builder.append_nulls(left_indices.len());
869 let right_indices = builder.finish();
870 (left_indices, right_indices)
871}
872
873pub(crate) fn apply_join_filter_to_indices(
874 build_input_buffer: &RecordBatch,
875 probe_batch: &RecordBatch,
876 build_indices: UInt64Array,
877 probe_indices: UInt32Array,
878 filter: &JoinFilter,
879 build_side: JoinSide,
880 max_intermediate_size: Option<usize>,
881) -> Result<(UInt64Array, UInt32Array)> {
882 if build_indices.is_empty() && probe_indices.is_empty() {
883 return Ok((build_indices, probe_indices));
884 };
885
886 let filter_result = if let Some(max_size) = max_intermediate_size {
887 let mut filter_results =
888 Vec::with_capacity(build_indices.len().div_ceil(max_size));
889
890 for i in (0..build_indices.len()).step_by(max_size) {
891 let end = min(build_indices.len(), i + max_size);
892 let len = end - i;
893 let intermediate_batch = build_batch_from_indices(
894 filter.schema(),
895 build_input_buffer,
896 probe_batch,
897 &build_indices.slice(i, len),
898 &probe_indices.slice(i, len),
899 filter.column_indices(),
900 build_side,
901 )?;
902 let filter_result = filter
903 .expression()
904 .evaluate(&intermediate_batch)?
905 .into_array(intermediate_batch.num_rows())?;
906 filter_results.push(filter_result);
907 }
908
909 let filter_refs: Vec<&dyn Array> =
910 filter_results.iter().map(|a| a.as_ref()).collect();
911
912 compute::concat(&filter_refs)?
913 } else {
914 let intermediate_batch = build_batch_from_indices(
915 filter.schema(),
916 build_input_buffer,
917 probe_batch,
918 &build_indices,
919 &probe_indices,
920 filter.column_indices(),
921 build_side,
922 )?;
923
924 filter
925 .expression()
926 .evaluate(&intermediate_batch)?
927 .into_array(intermediate_batch.num_rows())?
928 };
929
930 let mask = as_boolean_array(&filter_result)?;
931
932 let left_filtered = compute::filter(&build_indices, mask)?;
933 let right_filtered = compute::filter(&probe_indices, mask)?;
934 Ok((
935 downcast_array(left_filtered.as_ref()),
936 downcast_array(right_filtered.as_ref()),
937 ))
938}
939
940pub(crate) fn build_batch_from_indices(
943 schema: &Schema,
944 build_input_buffer: &RecordBatch,
945 probe_batch: &RecordBatch,
946 build_indices: &UInt64Array,
947 probe_indices: &UInt32Array,
948 column_indices: &[ColumnIndex],
949 build_side: JoinSide,
950) -> Result<RecordBatch> {
951 if schema.fields().is_empty() {
952 let options = RecordBatchOptions::new()
953 .with_match_field_names(true)
954 .with_row_count(Some(build_indices.len()));
955
956 return Ok(RecordBatch::try_new_with_options(
957 Arc::new(schema.clone()),
958 vec![],
959 &options,
960 )?);
961 }
962
963 let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
967
968 for column_index in column_indices {
969 let array = if column_index.side == JoinSide::None {
970 Arc::new(compute::is_not_null(probe_indices)?)
972 } else if column_index.side == build_side {
973 let array = build_input_buffer.column(column_index.index);
974 if array.is_empty() || build_indices.null_count() == build_indices.len() {
975 assert_eq!(build_indices.null_count(), build_indices.len());
979 new_null_array(array.data_type(), build_indices.len())
980 } else {
981 take(array.as_ref(), build_indices, None)?
982 }
983 } else {
984 let array = probe_batch.column(column_index.index);
985 if array.is_empty() || probe_indices.null_count() == probe_indices.len() {
986 assert_eq!(probe_indices.null_count(), probe_indices.len());
987 new_null_array(array.data_type(), probe_indices.len())
988 } else {
989 take(array.as_ref(), probe_indices, None)?
990 }
991 };
992
993 columns.push(array);
994 }
995 Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
996}
997
998pub(crate) fn build_batch_empty_build_side(
1001 schema: &Schema,
1002 build_batch: &RecordBatch,
1003 probe_batch: &RecordBatch,
1004 column_indices: &[ColumnIndex],
1005 join_type: JoinType,
1006) -> Result<RecordBatch> {
1007 match join_type {
1008 JoinType::Inner
1011 | JoinType::Left
1012 | JoinType::LeftSemi
1013 | JoinType::RightSemi
1014 | JoinType::LeftAnti
1015 | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))),
1016
1017 JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => {
1019 let num_rows = probe_batch.num_rows();
1020 let mut columns: Vec<Arc<dyn Array>> =
1021 Vec::with_capacity(schema.fields().len());
1022
1023 for column_index in column_indices {
1024 let array = match column_index.side {
1025 JoinSide::Left => new_null_array(
1027 build_batch.column(column_index.index).data_type(),
1028 num_rows,
1029 ),
1030 JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)),
1032 JoinSide::None => Arc::new(BooleanArray::new(
1034 BooleanBuffer::new_unset(num_rows),
1035 None,
1036 )),
1037 };
1038
1039 columns.push(array);
1040 }
1041
1042 Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
1043 }
1044 }
1045}
1046
1047pub(crate) fn adjust_indices_by_join_type(
1050 left_indices: UInt64Array,
1051 right_indices: UInt32Array,
1052 adjust_range: Range<usize>,
1053 join_type: JoinType,
1054 preserve_order_for_right: bool,
1055) -> Result<(UInt64Array, UInt32Array)> {
1056 match join_type {
1057 JoinType::Inner => {
1058 Ok((left_indices, right_indices))
1060 }
1061 JoinType::Left => {
1062 Ok((left_indices, right_indices))
1064 }
1066 JoinType::Right => {
1067 append_right_indices(
1069 left_indices,
1070 right_indices,
1071 adjust_range,
1072 preserve_order_for_right,
1073 )
1074 }
1075 JoinType::Full => {
1076 append_right_indices(left_indices, right_indices, adjust_range, false)
1077 }
1078 JoinType::RightSemi => {
1079 let right_indices = get_semi_indices(adjust_range, &right_indices);
1081 Ok((left_indices, right_indices))
1083 }
1084 JoinType::RightAnti => {
1085 let right_indices = get_anti_indices(adjust_range, &right_indices);
1088 Ok((left_indices, right_indices))
1090 }
1091 JoinType::RightMark => {
1092 let right_indices = get_mark_indices(&adjust_range, &right_indices);
1093 let left_indices_vec: Vec<u64> = adjust_range.map(|i| i as u64).collect();
1094 let left_indices = UInt64Array::from(left_indices_vec);
1095 Ok((left_indices, right_indices))
1096 }
1097 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
1098 Ok((
1101 UInt64Array::from_iter_values(vec![]),
1102 UInt32Array::from_iter_values(vec![]),
1103 ))
1104 }
1105 }
1106}
1107
1108pub(crate) fn append_right_indices(
1124 left_indices: UInt64Array,
1125 right_indices: UInt32Array,
1126 adjust_range: Range<usize>,
1127 preserve_order_for_right: bool,
1128) -> Result<(UInt64Array, UInt32Array)> {
1129 if preserve_order_for_right {
1130 Ok(append_probe_indices_in_order(
1131 left_indices,
1132 right_indices,
1133 adjust_range,
1134 ))
1135 } else {
1136 let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices);
1137
1138 if right_unmatched_indices.is_empty() {
1139 Ok((left_indices, right_indices))
1140 } else {
1141 let mut new_left_indices_builder =
1147 left_indices.into_builder().unwrap_or_else(|left_indices| {
1148 let mut builder = UInt64Builder::with_capacity(
1149 left_indices.len() + right_unmatched_indices.len(),
1150 );
1151 debug_assert_eq!(
1152 left_indices.null_count(),
1153 0,
1154 "expected left indices to have no nulls"
1155 );
1156 builder.append_slice(left_indices.values());
1157 builder
1158 });
1159 new_left_indices_builder.append_nulls(right_unmatched_indices.len());
1160 let new_left_indices = UInt64Array::from(new_left_indices_builder.finish());
1161
1162 let mut new_right_indices_builder = right_indices
1164 .into_builder()
1165 .unwrap_or_else(|right_indices| {
1166 let mut builder = UInt32Builder::with_capacity(
1167 right_indices.len() + right_unmatched_indices.len(),
1168 );
1169 debug_assert_eq!(
1170 right_indices.null_count(),
1171 0,
1172 "expected right indices to have no nulls"
1173 );
1174 builder.append_slice(right_indices.values());
1175 builder
1176 });
1177 debug_assert_eq!(
1178 right_unmatched_indices.null_count(),
1179 0,
1180 "expected right unmatched indices to have no nulls"
1181 );
1182 new_right_indices_builder.append_slice(right_unmatched_indices.values());
1183 let new_right_indices = UInt32Array::from(new_right_indices_builder.finish());
1184
1185 Ok((new_left_indices, new_right_indices))
1186 }
1187 }
1188}
1189
1190pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
1192 range: Range<usize>,
1193 input_indices: &PrimitiveArray<T>,
1194) -> PrimitiveArray<T>
1195where
1196 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1197{
1198 let bitmap = build_range_bitmap(&range, input_indices);
1199 let offset = range.start;
1200
1201 (range)
1203 .filter_map(|idx| {
1204 (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1205 })
1206 .collect()
1207}
1208
1209pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
1211 range: Range<usize>,
1212 input_indices: &PrimitiveArray<T>,
1213) -> PrimitiveArray<T>
1214where
1215 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1216{
1217 let bitmap = build_range_bitmap(&range, input_indices);
1218 let offset = range.start;
1219 (range)
1221 .filter_map(|idx| {
1222 (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1223 })
1224 .collect()
1225}
1226
1227pub(crate) fn get_mark_indices<T: ArrowPrimitiveType>(
1228 range: &Range<usize>,
1229 input_indices: &PrimitiveArray<T>,
1230) -> PrimitiveArray<UInt32Type>
1231where
1232 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1233{
1234 let mut bitmap = build_range_bitmap(range, input_indices);
1235 PrimitiveArray::new(
1236 vec![0; range.len()].into(),
1237 Some(NullBuffer::new(bitmap.finish())),
1238 )
1239}
1240
1241fn build_range_bitmap<T: ArrowPrimitiveType>(
1242 range: &Range<usize>,
1243 input: &PrimitiveArray<T>,
1244) -> BooleanBufferBuilder {
1245 let mut builder = BooleanBufferBuilder::new(range.len());
1246 builder.append_n(range.len(), false);
1247
1248 input.iter().flatten().for_each(|v| {
1249 let idx = v.as_usize();
1250 if range.contains(&idx) {
1251 builder.set_bit(idx - range.start, true);
1252 }
1253 });
1254
1255 builder
1256}
1257
1258fn append_probe_indices_in_order(
1276 build_indices: PrimitiveArray<UInt64Type>,
1277 probe_indices: PrimitiveArray<UInt32Type>,
1278 range: Range<usize>,
1279) -> (PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>) {
1280 let mut new_build_indices = UInt64Builder::new();
1282 let mut new_probe_indices = UInt32Builder::new();
1283 let mut prev_index = range.start as u32;
1285 debug_assert!(build_indices.len() == probe_indices.len());
1287 for (build_index, probe_index) in build_indices
1288 .values()
1289 .into_iter()
1290 .zip(probe_indices.values().into_iter())
1291 {
1292 for value in prev_index..*probe_index {
1294 new_probe_indices.append_value(value);
1295 new_build_indices.append_null();
1296 }
1297 new_probe_indices.append_value(*probe_index);
1299 new_build_indices.append_value(*build_index);
1300 prev_index = probe_index + 1;
1302 }
1303 for value in prev_index..range.end as u32 {
1305 new_probe_indices.append_value(value);
1306 new_build_indices.append_null();
1307 }
1308 (new_build_indices.finish(), new_probe_indices.finish())
1310}
1311
/// Metric handles for a join operator with a build side and a probe side.
///
/// All handles are created and registered by `BuildProbeJoinMetrics::new`.
/// NOTE(review): the per-field meanings below are taken from the metric names;
/// confirm against the operators that update them.
#[derive(Clone, Debug)]
pub(crate) struct BuildProbeJoinMetrics {
    /// Baseline operator metrics. When this struct is dropped, `build_time`
    /// and `join_time` are added to its `elapsed_compute`.
    pub(crate) baseline: BaselineMetrics,
    /// Timer registered as `build_time` (presumably time spent on the build side).
    pub(crate) build_time: metrics::Time,
    /// Counter registered as `build_input_batches`.
    pub(crate) build_input_batches: metrics::Count,
    /// Counter registered as `build_input_rows`.
    pub(crate) build_input_rows: metrics::Count,
    /// Gauge registered as `build_mem_used` (presumably bytes; verify at the call sites).
    pub(crate) build_mem_used: metrics::Gauge,
    /// Timer registered as `join_time` (presumably time spent joining).
    pub(crate) join_time: metrics::Time,
    /// Counter registered as `input_batches`.
    pub(crate) input_batches: metrics::Count,
    /// Counter registered as `input_rows`.
    pub(crate) input_rows: metrics::Count,
    /// Counter registered as `output_batches`.
    pub(crate) output_batches: metrics::Count,
}
1333
1334impl Drop for BuildProbeJoinMetrics {
1347 fn drop(&mut self) {
1348 self.baseline.elapsed_compute().add(&self.build_time);
1349 self.baseline.elapsed_compute().add(&self.join_time);
1350 }
1351}
1352
1353impl BuildProbeJoinMetrics {
1354 pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
1355 let baseline = BaselineMetrics::new(metrics, partition);
1356
1357 let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
1358
1359 let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition);
1360
1361 let build_input_batches =
1362 MetricBuilder::new(metrics).counter("build_input_batches", partition);
1363
1364 let build_input_rows =
1365 MetricBuilder::new(metrics).counter("build_input_rows", partition);
1366
1367 let build_mem_used =
1368 MetricBuilder::new(metrics).gauge("build_mem_used", partition);
1369
1370 let input_batches =
1371 MetricBuilder::new(metrics).counter("input_batches", partition);
1372
1373 let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
1374
1375 let output_batches =
1376 MetricBuilder::new(metrics).counter("output_batches", partition);
1377
1378 Self {
1379 build_time,
1380 build_input_batches,
1381 build_input_rows,
1382 build_mem_used,
1383 join_time,
1384 input_batches,
1385 input_rows,
1386 output_batches,
1387 baseline,
1388 }
1389 }
1390}
1391
/// Turns the outcome of one state-machine step into a `Poll` value.
///
/// `$match_case` must evaluate to a `Result<StatefulStreamResult<Option<T>>, E>`:
///
/// * `Err(e)` becomes `Poll::Ready(Some(Err(e)))`.
/// * `Ok(StatefulStreamResult::Ready(Some(v)))` becomes `Poll::Ready(Some(Ok(v)))`,
///   and `Ready(None)` becomes `Poll::Ready(None)`.
/// * `Ok(StatefulStreamResult::Continue)` expands to `continue`, so the macro
///   has to be invoked inside a loop.
#[macro_export]
macro_rules! handle_state {
    ($match_case:expr) => {
        match $match_case {
            Err(e) => Poll::Ready(Some(Err(e))),
            Ok(StatefulStreamResult::Ready(result)) => Poll::Ready(result.map(Ok)),
            Ok(StatefulStreamResult::Continue) => continue,
        }
    };
}
1422
/// Outcome of a single step of a stateful stream, as consumed by the
/// `handle_state!` macro.
pub enum StatefulStreamResult<T> {
    /// The step produced a value; `handle_state!` wraps it in `Poll::Ready`.
    Ready(T),
    /// The step produced nothing yet; `handle_state!` expands to `continue`
    /// so the surrounding loop runs another step.
    Continue,
}
1438
1439pub(crate) fn symmetric_join_output_partitioning(
1440 left: &Arc<dyn ExecutionPlan>,
1441 right: &Arc<dyn ExecutionPlan>,
1442 join_type: &JoinType,
1443) -> Result<Partitioning> {
1444 let left_columns_len = left.schema().fields.len();
1445 let left_partitioning = left.output_partitioning();
1446 let right_partitioning = right.output_partitioning();
1447 let result = match join_type {
1448 JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
1449 left_partitioning.clone()
1450 }
1451 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1452 right_partitioning.clone()
1453 }
1454 JoinType::Inner | JoinType::Right => {
1455 adjust_right_output_partitioning(right_partitioning, left_columns_len)?
1456 }
1457 JoinType::Full => {
1458 Partitioning::UnknownPartitioning(right_partitioning.partition_count())
1460 }
1461 };
1462 Ok(result)
1463}
1464
1465pub(crate) fn asymmetric_join_output_partitioning(
1466 left: &Arc<dyn ExecutionPlan>,
1467 right: &Arc<dyn ExecutionPlan>,
1468 join_type: &JoinType,
1469) -> Result<Partitioning> {
1470 let result = match join_type {
1471 JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
1472 right.output_partitioning(),
1473 left.schema().fields().len(),
1474 )?,
1475 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1476 right.output_partitioning().clone()
1477 }
1478 JoinType::Left
1479 | JoinType::LeftSemi
1480 | JoinType::LeftAnti
1481 | JoinType::Full
1482 | JoinType::LeftMark => Partitioning::UnknownPartitioning(
1483 right.output_partitioning().partition_count(),
1484 ),
1485 };
1486 Ok(result)
1487}
1488
/// Interface for handing a batch in and pulling transformed batches out.
pub(crate) trait BatchTransformer: Debug + Clone {
    /// Stores `batch` as the input for the following `next` calls.
    fn set_batch(&mut self, batch: RecordBatch);

    /// Returns the next output batch together with a flag, or `None` when no
    /// batch is pending. In the implementations in this file the flag is
    /// `true` on the final piece produced from the current input batch.
    fn next(&mut self) -> Option<(RecordBatch, bool)>;
}
1502
1503#[derive(Debug, Clone)]
1504pub(crate) struct NoopBatchTransformer {
1506 batch: Option<RecordBatch>,
1508}
1509
1510impl NoopBatchTransformer {
1511 pub fn new() -> Self {
1512 Self { batch: None }
1513 }
1514}
1515
1516impl BatchTransformer for NoopBatchTransformer {
1517 fn set_batch(&mut self, batch: RecordBatch) {
1518 self.batch = Some(batch);
1519 }
1520
1521 fn next(&mut self) -> Option<(RecordBatch, bool)> {
1522 self.batch.take().map(|batch| (batch, true))
1523 }
1524}
1525
1526#[derive(Debug, Clone)]
1527pub(crate) struct BatchSplitter {
1529 batch: Option<RecordBatch>,
1531 batch_size: usize,
1533 row_index: usize,
1535}
1536
1537impl BatchSplitter {
1538 pub(crate) fn new(batch_size: usize) -> Self {
1540 Self {
1541 batch: None,
1542 batch_size,
1543 row_index: 0,
1544 }
1545 }
1546}
1547
1548impl BatchTransformer for BatchSplitter {
1549 fn set_batch(&mut self, batch: RecordBatch) {
1550 self.batch = Some(batch);
1551 self.row_index = 0;
1552 }
1553
1554 fn next(&mut self) -> Option<(RecordBatch, bool)> {
1555 let Some(batch) = &self.batch else {
1556 return None;
1557 };
1558
1559 let remaining_rows = batch.num_rows() - self.row_index;
1560 let rows_to_slice = remaining_rows.min(self.batch_size);
1561 let sliced_batch = batch.slice(self.row_index, rows_to_slice);
1562 self.row_index += rows_to_slice;
1563
1564 let mut last = false;
1565 if self.row_index >= batch.num_rows() {
1566 self.batch = None;
1567 last = true;
1568 }
1569
1570 Some((sliced_batch, last))
1571 }
1572}
1573
1574pub fn reorder_output_after_swap(
1581 plan: Arc<dyn ExecutionPlan>,
1582 left_schema: &Schema,
1583 right_schema: &Schema,
1584) -> Result<Arc<dyn ExecutionPlan>> {
1585 let proj = ProjectionExec::try_new(
1586 swap_reverting_projection(left_schema, right_schema),
1587 plan,
1588 )?;
1589 Ok(Arc::new(proj))
1590}
1591
1592fn swap_reverting_projection(
1598 left_schema: &Schema,
1599 right_schema: &Schema,
1600) -> Vec<ProjectionExpr> {
1601 let right_cols =
1602 right_schema
1603 .fields()
1604 .iter()
1605 .enumerate()
1606 .map(|(i, f)| ProjectionExpr {
1607 expr: Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
1608 alias: f.name().to_owned(),
1609 });
1610 let right_len = right_cols.len();
1611 let left_cols =
1612 left_schema
1613 .fields()
1614 .iter()
1615 .enumerate()
1616 .map(|(i, f)| ProjectionExpr {
1617 expr: Arc::new(Column::new(f.name(), right_len + i))
1618 as Arc<dyn PhysicalExpr>,
1619 alias: f.name().to_owned(),
1620 });
1621
1622 left_cols.chain(right_cols).collect()
1623}
1624
1625pub fn swap_join_projection(
1627 left_schema_len: usize,
1628 right_schema_len: usize,
1629 projection: Option<&Vec<usize>>,
1630 join_type: &JoinType,
1631) -> Option<Vec<usize>> {
1632 match join_type {
1633 JoinType::LeftAnti
1636 | JoinType::LeftSemi
1637 | JoinType::RightAnti
1638 | JoinType::RightSemi
1639 | JoinType::LeftMark
1640 | JoinType::RightMark => projection.cloned(),
1641 _ => projection.map(|p| {
1642 p.iter()
1643 .map(|i| {
1644 if *i < left_schema_len {
1649 *i + right_schema_len
1650 } else {
1651 *i - left_schema_len
1652 }
1653 })
1654 .collect()
1655 }),
1656 }
1657}
1658
1659#[allow(clippy::too_many_arguments)]
1666pub fn update_hash(
1667 on: &[PhysicalExprRef],
1668 batch: &RecordBatch,
1669 hash_map: &mut dyn JoinHashMapType,
1670 offset: usize,
1671 random_state: &RandomState,
1672 hashes_buffer: &mut Vec<u64>,
1673 deleted_offset: usize,
1674 fifo_hashmap: bool,
1675) -> Result<()> {
1676 let keys_values = on
1678 .iter()
1679 .map(|c| c.evaluate(batch)?.into_array(batch.num_rows()))
1680 .collect::<Result<Vec<_>>>()?;
1681
1682 let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1684
1685 hash_map.extend_zero(batch.num_rows());
1687
1688 let hash_values_iter = hash_values
1690 .iter()
1691 .enumerate()
1692 .map(|(i, val)| (i + offset, val));
1693
1694 if fifo_hashmap {
1695 hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset);
1696 } else {
1697 hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset);
1698 }
1699
1700 Ok(())
1701}
1702
1703pub(super) fn equal_rows_arr(
1704 indices_left: &UInt64Array,
1705 indices_right: &UInt32Array,
1706 left_arrays: &[ArrayRef],
1707 right_arrays: &[ArrayRef],
1708 null_equality: NullEquality,
1709) -> Result<(UInt64Array, UInt32Array)> {
1710 let mut iter = left_arrays.iter().zip(right_arrays.iter());
1711
1712 let Some((first_left, first_right)) = iter.next() else {
1713 return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
1714 };
1715
1716 let arr_left = take(first_left.as_ref(), indices_left, None)?;
1717 let arr_right = take(first_right.as_ref(), indices_right, None)?;
1718
1719 let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;
1720
1721 equal = iter
1725 .map(|(left, right)| {
1726 let arr_left = take(left.as_ref(), indices_left, None)?;
1727 let arr_right = take(right.as_ref(), indices_right, None)?;
1728 eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
1729 })
1730 .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
1731
1732 let filter_builder = FilterBuilder::new(&equal).optimize().build();
1733
1734 let left_filtered = filter_builder.filter(indices_left)?;
1735 let right_filtered = filter_builder.filter(indices_right)?;
1736
1737 Ok((
1738 downcast_array(left_filtered.as_ref()),
1739 downcast_array(right_filtered.as_ref()),
1740 ))
1741}
1742
1743fn eq_dyn_null(
1745 left: &dyn Array,
1746 right: &dyn Array,
1747 null_equality: NullEquality,
1748) -> Result<BooleanArray, ArrowError> {
1749 if left.data_type().is_nested() {
1753 let op = match null_equality {
1754 NullEquality::NullEqualsNothing => Operator::Eq,
1755 NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
1756 };
1757 return Ok(compare_op_for_nested(op, &left, &right)?);
1758 }
1759 match null_equality {
1760 NullEquality::NullEqualsNothing => eq(&left, &right),
1761 NullEquality::NullEqualsNull => not_distinct(&left, &right),
1762 }
1763}
1764
/// Compares row `left` of `left_arrays` with row `right` of `right_arrays`,
/// column by column, and returns the first non-equal ordering (or
/// `Ordering::Equal` if every column compares equal).
///
/// Columns are visited as the zip of `left_arrays`, `right_arrays` and
/// `sort_options`, so the shortest of the three bounds the comparison.
///
/// # Errors
/// Returns a not-implemented error for a column whose data type is not one of
/// the types listed in the `match` below.
///
/// # Panics
/// Panics if a column's concrete array type does not match its data type on
/// either side (failed downcast), or if two non-null values are not
/// comparable via `partial_cmp` (e.g. a float NaN).
pub fn compare_join_arrays(
    left_arrays: &[ArrayRef],
    left: usize,
    right_arrays: &[ArrayRef],
    right: usize,
    sort_options: &[SortOptions],
    null_equality: NullEquality,
) -> Result<Ordering> {
    let mut res = Ordering::Equal;
    for ((left_array, right_array), sort_options) in
        left_arrays.iter().zip(right_arrays).zip(sort_options)
    {
        // Compares the two rows of the current column after downcasting both
        // arrays to `$T`, and stores the outcome in `res`.
        macro_rules! compare_value {
            ($T:ty) => {{
                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
                match (left_array.is_null(left), right_array.is_null(right)) {
                    // Both values present: natural order, flipped for descending columns.
                    (false, false) => {
                        let left_value = &left_array.value(left);
                        let right_value = &right_array.value(right);
                        res = left_value.partial_cmp(right_value).unwrap();
                        if sort_options.descending {
                            res = res.reverse();
                        }
                    }
                    // Only the left value is null: its position follows `nulls_first`.
                    (true, false) => {
                        res = if sort_options.nulls_first {
                            Ordering::Less
                        } else {
                            Ordering::Greater
                        };
                    }
                    // Only the right value is null: mirror image of the case above.
                    (false, true) => {
                        res = if sort_options.nulls_first {
                            Ordering::Greater
                        } else {
                            Ordering::Less
                        };
                    }
                    // Both null: equal only when nulls are configured to equal nulls.
                    _ => {
                        res = match null_equality {
                            NullEquality::NullEqualsNothing => Ordering::Less,
                            NullEquality::NullEqualsNull => Ordering::Equal,
                        };
                    }
                }
            }};
        }

        // Dispatch on the left column's data type; the right column is
        // downcast to the same array type.
        match left_array.data_type() {
            // Null-typed columns leave `res` untouched.
            DataType::Null => {}
            DataType::Boolean => compare_value!(BooleanArray),
            DataType::Int8 => compare_value!(Int8Array),
            DataType::Int16 => compare_value!(Int16Array),
            DataType::Int32 => compare_value!(Int32Array),
            DataType::Int64 => compare_value!(Int64Array),
            DataType::UInt8 => compare_value!(UInt8Array),
            DataType::UInt16 => compare_value!(UInt16Array),
            DataType::UInt32 => compare_value!(UInt32Array),
            DataType::UInt64 => compare_value!(UInt64Array),
            DataType::Float32 => compare_value!(Float32Array),
            DataType::Float64 => compare_value!(Float64Array),
            DataType::Binary => compare_value!(BinaryArray),
            DataType::BinaryView => compare_value!(BinaryViewArray),
            DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
            DataType::LargeBinary => compare_value!(LargeBinaryArray),
            DataType::Utf8 => compare_value!(StringArray),
            DataType::Utf8View => compare_value!(StringViewArray),
            DataType::LargeUtf8 => compare_value!(LargeStringArray),
            DataType::Decimal128(..) => compare_value!(Decimal128Array),
            // Only timestamps without a time zone are handled here.
            DataType::Timestamp(time_unit, None) => match time_unit {
                TimeUnit::Second => compare_value!(TimestampSecondArray),
                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
            },
            DataType::Date32 => compare_value!(Date32Array),
            DataType::Date64 => compare_value!(Date64Array),
            dt => {
                return not_impl_err!(
                    "Unsupported data type in sort merge join comparator: {}",
                    dt
                );
            }
        }
        // The first column that differs decides the result.
        if !res.is_eq() {
            break;
        }
    }
    Ok(res)
}
1857
1858#[cfg(test)]
1859mod tests {
1860 use std::collections::HashMap;
1861 use std::pin::Pin;
1862
1863 use super::*;
1864
1865 use arrow::array::Int32Array;
1866 use arrow::datatypes::{DataType, Fields};
1867 use arrow::error::{ArrowError, Result as ArrowResult};
1868 use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
1869 use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
1870 use datafusion_physical_expr::PhysicalSortExpr;
1871
1872 use rstest::rstest;
1873
1874 fn check(
1875 left: &[Column],
1876 right: &[Column],
1877 on: &[(PhysicalExprRef, PhysicalExprRef)],
1878 ) -> Result<()> {
1879 let left = left
1880 .iter()
1881 .map(|x| x.to_owned())
1882 .collect::<HashSet<Column>>();
1883 let right = right
1884 .iter()
1885 .map(|x| x.to_owned())
1886 .collect::<HashSet<Column>>();
1887 check_join_set_is_valid(&left, &right, on)
1888 }
1889
1890 #[test]
1891 fn check_valid() -> Result<()> {
1892 let left = vec![Column::new("a", 0), Column::new("b1", 1)];
1893 let right = vec![Column::new("a", 0), Column::new("b2", 1)];
1894 let on = &[(
1895 Arc::new(Column::new("a", 0)) as _,
1896 Arc::new(Column::new("a", 0)) as _,
1897 )];
1898
1899 check(&left, &right, on)?;
1900 Ok(())
1901 }
1902
1903 #[test]
1904 fn check_not_in_right() {
1905 let left = vec![Column::new("a", 0), Column::new("b", 1)];
1906 let right = vec![Column::new("b", 0)];
1907 let on = &[(
1908 Arc::new(Column::new("a", 0)) as _,
1909 Arc::new(Column::new("a", 0)) as _,
1910 )];
1911
1912 assert!(check(&left, &right, on).is_err());
1913 }
1914
1915 #[tokio::test]
1916 async fn check_error_nesting() {
1917 let once_fut = OnceFut::<()>::new(async {
1918 arrow_err!(ArrowError::CsvError("some error".to_string()))
1919 });
1920
1921 struct TestFut(OnceFut<()>);
1922 impl Future for TestFut {
1923 type Output = ArrowResult<()>;
1924
1925 fn poll(
1926 mut self: Pin<&mut Self>,
1927 cx: &mut Context<'_>,
1928 ) -> Poll<Self::Output> {
1929 match ready!(self.0.get(cx)) {
1930 Ok(()) => Poll::Ready(Ok(())),
1931 Err(e) => Poll::Ready(Err(e.into())),
1932 }
1933 }
1934 }
1935
1936 let res = TestFut(once_fut).await;
1937 let arrow_err_from_fut = res.expect_err("once_fut always return error");
1938
1939 let wrapped_err = DataFusionError::from(arrow_err_from_fut);
1940 let root_err = wrapped_err.find_root();
1941
1942 let _expected =
1943 arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned()));
1944
1945 assert!(matches!(root_err, _expected))
1946 }
1947
1948 #[test]
1949 fn check_not_in_left() {
1950 let left = vec![Column::new("b", 0)];
1951 let right = vec![Column::new("a", 0)];
1952 let on = &[(
1953 Arc::new(Column::new("a", 0)) as _,
1954 Arc::new(Column::new("a", 0)) as _,
1955 )];
1956
1957 assert!(check(&left, &right, on).is_err());
1958 }
1959
1960 #[test]
1961 fn check_collision() {
1962 let left = vec![Column::new("a", 0), Column::new("c", 1)];
1964 let right = vec![Column::new("a", 0), Column::new("b", 1)];
1965 let on = &[(
1966 Arc::new(Column::new("a", 0)) as _,
1967 Arc::new(Column::new("b", 1)) as _,
1968 )];
1969
1970 assert!(check(&left, &right, on).is_ok());
1971 }
1972
1973 #[test]
1974 fn check_in_right() {
1975 let left = vec![Column::new("a", 0), Column::new("c", 1)];
1976 let right = vec![Column::new("b", 0)];
1977 let on = &[(
1978 Arc::new(Column::new("a", 0)) as _,
1979 Arc::new(Column::new("b", 0)) as _,
1980 )];
1981
1982 assert!(check(&left, &right, on).is_ok());
1983 }
1984
1985 #[test]
1986 fn test_join_schema() -> Result<()> {
1987 let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1988 let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1989 let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
1990 let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
1991
1992 let cases = vec![
1993 (&a, &b, JoinType::Inner, &a, &b),
1994 (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
1995 (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
1996 (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
1997 (&a, &b, JoinType::Left, &a, &b_nulls),
1999 (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
2000 (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
2001 (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
2002 (&a, &b, JoinType::Right, &a_nulls, &b),
2004 (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
2005 (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
2006 (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
2007 (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
2009 (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
2010 (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
2011 (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
2012 ];
2013
2014 for (left_in, right_in, join_type, left_out, right_out) in cases {
2015 let (schema, _) = build_join_schema(left_in, right_in, &join_type);
2016
2017 let expected_fields = left_out
2018 .fields()
2019 .iter()
2020 .cloned()
2021 .chain(right_out.fields().iter().cloned())
2022 .collect::<Fields>();
2023
2024 let expected_schema = Schema::new(expected_fields);
2025 assert_eq!(
2026 schema,
2027 expected_schema,
2028 "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
2029 left_in.fields()[0].name(),
2030 left_in.fields()[0].is_nullable(),
2031 right_in.fields()[0].name(),
2032 right_in.fields()[0].is_nullable(),
2033 join_type
2034 );
2035 }
2036
2037 Ok(())
2038 }
2039
2040 fn create_stats(
2041 num_rows: Option<usize>,
2042 column_stats: Vec<ColumnStatistics>,
2043 is_exact: bool,
2044 ) -> Statistics {
2045 Statistics {
2046 num_rows: if is_exact {
2047 num_rows.map(Exact)
2048 } else {
2049 num_rows.map(Inexact)
2050 }
2051 .unwrap_or(Absent),
2052 column_statistics: column_stats,
2053 total_byte_size: Absent,
2054 }
2055 }
2056
2057 fn create_column_stats(
2058 min: Precision<i64>,
2059 max: Precision<i64>,
2060 distinct_count: Precision<usize>,
2061 null_count: Precision<usize>,
2062 ) -> ColumnStatistics {
2063 ColumnStatistics {
2064 distinct_count,
2065 min_value: min.map(ScalarValue::from),
2066 max_value: max.map(ScalarValue::from),
2067 sum_value: Absent,
2068 null_count,
2069 }
2070 }
2071
    /// Compact description of one single-column input used by the table-driven
    /// tests below: `(num_rows, min, max, distinct_count, null_count)`.
    type PartialStats = (
        usize,
        Precision<i64>,
        Precision<i64>,
        Precision<usize>,
        Precision<usize>,
    );
2079
    /// Table-driven check of `estimate_inner_join_cardinality` and
    /// `estimate_join_cardinality` (inner join) for single-column inputs.
    ///
    /// Each case is `(left, right, expected)`, where `left` / `right` are
    /// `(num_rows, min, max, distinct_count, null_count)` and `expected` is
    /// the expected estimate (`None` when no estimate is expected).
    #[test]
    fn test_inner_join_cardinality_single_column() -> Result<()> {
        let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
            // Identical ranges, no distinct counts.
            (
                (10, Inexact(1), Inexact(10), Absent, Absent),
                (10, Inexact(1), Inexact(10), Absent, Absent),
                Some(Inexact(10)),
            ),
            // Left range wider than the right range, same upper bound.
            (
                (10, Inexact(6), Inexact(10), Absent, Absent),
                (10, Inexact(8), Inexact(10), Absent, Absent),
                Some(Inexact(20)),
            ),
            // Right range wider than the left range, same upper bound.
            (
                (10, Inexact(8), Inexact(10), Absent, Absent),
                (10, Inexact(6), Inexact(10), Absent, Absent),
                Some(Inexact(20)),
            ),
            // Ranges wider than the row counts on both sides.
            (
                (10, Inexact(1), Inexact(15), Absent, Absent),
                (20, Inexact(1), Inexact(40), Absent, Absent),
                Some(Inexact(10)),
            ),
            // Distinct counts given and equal to the range width.
            (
                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
                Some(Inexact(10)),
            ),
            // Distinct counts given and larger than the range width.
            (
                (10, Inexact(1), Inexact(3), Inexact(10), Absent),
                (10, Inexact(1), Inexact(3), Inexact(10), Absent),
                Some(Inexact(10)),
            ),
            // Distinct counts smaller than the range width, left larger.
            (
                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
                Some(Inexact(20)),
            ),
            // Distinct counts smaller than the range width, right larger.
            (
                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
                Some(Inexact(20)),
            ),
            // Left range starts below zero.
            (
                (10, Inexact(-5), Inexact(5), Absent, Absent),
                (10, Inexact(1), Inexact(5), Absent, Absent),
                Some(Inexact(10)),
            ),
            // Both ranges entirely negative.
            (
                (10, Inexact(-25), Inexact(-20), Absent, Absent),
                (10, Inexact(-25), Inexact(-15), Absent, Absent),
                Some(Inexact(10)),
            ),
            // Ranges touch only at zero; distinct count known on the right only.
            (
                (10, Inexact(-10), Inexact(0), Absent, Absent),
                (10, Inexact(0), Inexact(10), Inexact(5), Absent),
                Some(Inexact(10)),
            ),
            // Both ranges consist of the single value 1.
            (
                (10, Inexact(1), Inexact(1), Absent, Absent),
                (10, Inexact(1), Inexact(1), Absent, Absent),
                Some(Inexact(100)),
            ),
            // No column statistics at all.
            (
                (10, Absent, Absent, Absent, Absent),
                (10, Absent, Absent, Absent, Absent),
                Some(Inexact(10)),
            ),
            // Only distinct counts are known.
            (
                (10, Absent, Absent, Inexact(3), Absent),
                (10, Absent, Absent, Inexact(3), Absent),
                Some(Inexact(33)),
            ),
            // Distinct counts plus a left minimum and a right maximum.
            (
                (10, Inexact(2), Absent, Inexact(3), Absent),
                (10, Absent, Inexact(5), Inexact(3), Absent),
                Some(Inexact(33)),
            ),
            // Distinct counts plus a left maximum and a right minimum.
            (
                (10, Absent, Inexact(3), Inexact(3), Absent),
                (10, Inexact(1), Absent, Inexact(3), Absent),
                Some(Inexact(33)),
            ),
            // Left maximum and right minimum only, bounds overlapping.
            (
                (10, Absent, Inexact(3), Absent, Absent),
                (10, Inexact(1), Absent, Absent, Absent),
                Some(Inexact(10)),
            ),
            // Left maximum below the right minimum.
            (
                (10, Absent, Inexact(4), Absent, Absent),
                (10, Inexact(5), Absent, Absent, Absent),
                Some(Inexact(0)),
            ),
            // Disjoint ranges, left below right.
            (
                (10, Inexact(0), Inexact(10), Absent, Absent),
                (10, Inexact(11), Inexact(20), Absent, Absent),
                Some(Inexact(0)),
            ),
            // Disjoint ranges, left above right.
            (
                (10, Inexact(11), Inexact(20), Absent, Absent),
                (10, Inexact(0), Inexact(10), Absent, Absent),
                Some(Inexact(0)),
            ),
            // Distinct count of zero on both sides: no estimate expected.
            (
                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
                None,
            ),
            // Left input has zero rows (with an exact null count of 5).
            (
                (0, Inexact(1), Inexact(10), Absent, Exact(5)),
                (10, Inexact(1), Inexact(10), Absent, Absent),
                Some(Inexact(0)),
            ),
        ];

        for (left_info, right_info, expected_cardinality) in cases {
            let left_num_rows = left_info.0;
            let left_col_stats = vec![create_column_stats(
                left_info.1,
                left_info.2,
                left_info.3,
                left_info.4,
            )];

            let right_num_rows = right_info.0;
            let right_col_stats = vec![create_column_stats(
                right_info.1,
                right_info.2,
                right_info.3,
                right_info.4,
            )];

            // Direct call of the inner-join estimator.
            assert_eq!(
                estimate_inner_join_cardinality(
                    Statistics {
                        num_rows: Inexact(left_num_rows),
                        total_byte_size: Absent,
                        column_statistics: left_col_stats.clone(),
                    },
                    Statistics {
                        num_rows: Inexact(right_num_rows),
                        total_byte_size: Absent,
                        column_statistics: right_col_stats.clone(),
                    },
                ),
                expected_cardinality.clone()
            );

            // The same inputs through the general entry point with an inner join.
            let join_type = JoinType::Inner;
            let join_on = vec![(
                Arc::new(Column::new("a", 0)) as _,
                Arc::new(Column::new("b", 0)) as _,
            )];
            let partial_join_stats = estimate_join_cardinality(
                &join_type,
                create_stats(Some(left_num_rows), left_col_stats.clone(), false),
                create_stats(Some(right_num_rows), right_col_stats.clone(), false),
                &join_on,
            );

            assert_eq!(
                partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
                expected_cardinality.clone()
            );
            // Column statistics are expected to be the left ones followed by the right ones.
            assert_eq!(
                partial_join_stats.map(|s| s.column_statistics),
                expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
            );
        }
        Ok(())
    }
2289
2290 #[test]
2291 fn test_inner_join_cardinality_multiple_column() -> Result<()> {
2292 let left_col_stats = vec![
2293 create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2294 create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
2295 ];
2296
2297 let right_col_stats = vec![
2298 create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2299 create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
2300 ];
2301
2302 assert_eq!(
2305 estimate_inner_join_cardinality(
2306 Statistics {
2307 num_rows: Inexact(400),
2308 total_byte_size: Absent,
2309 column_statistics: left_col_stats,
2310 },
2311 Statistics {
2312 num_rows: Inexact(400),
2313 total_byte_size: Absent,
2314 column_statistics: right_col_stats,
2315 },
2316 ),
2317 Some(Inexact((400 * 400) / 200))
2318 );
2319 Ok(())
2320 }
2321
2322 #[test]
2323 fn test_inner_join_cardinality_decimal_range() -> Result<()> {
2324 let left_col_stats = vec![ColumnStatistics {
2325 distinct_count: Absent,
2326 min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)),
2327 max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)),
2328 ..Default::default()
2329 }];
2330
2331 let right_col_stats = vec![ColumnStatistics {
2332 distinct_count: Absent,
2333 min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)),
2334 max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)),
2335 ..Default::default()
2336 }];
2337
2338 assert_eq!(
2339 estimate_inner_join_cardinality(
2340 Statistics {
2341 num_rows: Inexact(100),
2342 total_byte_size: Absent,
2343 column_statistics: left_col_stats,
2344 },
2345 Statistics {
2346 num_rows: Inexact(100),
2347 total_byte_size: Absent,
2348 column_statistics: right_col_stats,
2349 },
2350 ),
2351 Some(Inexact(100))
2352 );
2353 Ok(())
2354 }
2355
2356 #[test]
2357 fn test_join_cardinality() -> Result<()> {
2358 let cases = vec![
2370 (JoinType::Inner, 800),
2371 (JoinType::Left, 1000),
2372 (JoinType::Right, 2000),
2373 (JoinType::Full, 2200),
2374 ];
2375
2376 let left_col_stats = vec![
2377 create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2378 create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2379 create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2380 ];
2381
2382 let right_col_stats = vec![
2383 create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2384 create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2385 create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2386 ];
2387
2388 for (join_type, expected_num_rows) in cases {
2389 let join_on = vec![
2390 (
2391 Arc::new(Column::new("a", 0)) as _,
2392 Arc::new(Column::new("c", 0)) as _,
2393 ),
2394 (
2395 Arc::new(Column::new("b", 1)) as _,
2396 Arc::new(Column::new("d", 1)) as _,
2397 ),
2398 ];
2399
2400 let partial_join_stats = estimate_join_cardinality(
2401 &join_type,
2402 create_stats(Some(1000), left_col_stats.clone(), false),
2403 create_stats(Some(2000), right_col_stats.clone(), false),
2404 &join_on,
2405 )
2406 .unwrap();
2407 assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2408 assert_eq!(
2409 partial_join_stats.column_statistics,
2410 [left_col_stats.clone(), right_col_stats.clone()].concat()
2411 );
2412 }
2413
2414 Ok(())
2415 }
2416
2417 #[test]
2418 fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
2419 let left_col_stats = vec![
2432 create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2433 create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2434 create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2435 ];
2436
2437 let right_col_stats = vec![
2438 create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2439 create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2440 create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2441 ];
2442
2443 let join_on = vec![
2444 (
2445 Arc::new(Column::new("a", 0)) as _,
2446 Arc::new(Column::new("c", 0)) as _,
2447 ),
2448 (
2449 Arc::new(Column::new("x", 2)) as _,
2450 Arc::new(Column::new("y", 2)) as _,
2451 ),
2452 ];
2453
2454 let cases = vec![
2455 (JoinType::Inner, 0),
2460 (JoinType::Left, 1000),
2463 (JoinType::Right, 2000),
2464 (JoinType::Full, 3000),
2468 ];
2469
2470 for (join_type, expected_num_rows) in cases {
2471 let partial_join_stats = estimate_join_cardinality(
2472 &join_type,
2473 create_stats(Some(1000), left_col_stats.clone(), true),
2474 create_stats(Some(2000), right_col_stats.clone(), true),
2475 &join_on,
2476 )
2477 .unwrap();
2478 assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2479 assert_eq!(
2480 partial_join_stats.column_statistics,
2481 [left_col_stats.clone(), right_col_stats.clone()].concat()
2482 );
2483 }
2484
2485 Ok(())
2486 }
2487
    /// Table-driven check of `estimate_join_cardinality` for semi and anti
    /// joins on a single column pair.
    ///
    /// Each case is `(join_type, first_input, second_input, expected_rows)`.
    /// The two inputs are `(num_rows, min, max, distinct_count, null_count)`
    /// and are passed to the estimator in that order, i.e. as left and right.
    #[test]
    fn test_anti_semi_join_cardinality() -> Result<()> {
        let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
            // Semi joins with overlapping ranges.
            (
                JoinType::LeftSemi,
                (50, Inexact(10), Inexact(20), Absent, Absent),
                (10, Inexact(15), Inexact(25), Absent, Absent),
                Some(50),
            ),
            (
                JoinType::RightSemi,
                (50, Inexact(10), Inexact(20), Absent, Absent),
                (10, Inexact(15), Inexact(25), Absent, Absent),
                Some(10),
            ),
            // Semi join without any column statistics.
            (
                JoinType::LeftSemi,
                (10, Absent, Absent, Absent, Absent),
                (50, Absent, Absent, Absent, Absent),
                Some(10),
            ),
            // Semi joins with non-overlapping bounds.
            (
                JoinType::LeftSemi,
                (50, Inexact(10), Inexact(20), Absent, Absent),
                (10, Inexact(30), Inexact(40), Absent, Absent),
                Some(0),
            ),
            (
                JoinType::LeftSemi,
                (50, Inexact(10), Absent, Absent, Absent),
                (10, Absent, Inexact(5), Absent, Absent),
                Some(0),
            ),
            (
                JoinType::LeftSemi,
                (50, Absent, Inexact(20), Absent, Absent),
                (10, Inexact(30), Absent, Absent, Absent),
                Some(0),
            ),
            // Anti joins with overlapping ranges.
            (
                JoinType::LeftAnti,
                (50, Inexact(10), Inexact(20), Absent, Absent),
                (10, Inexact(15), Inexact(25), Absent, Absent),
                Some(50),
            ),
            (
                JoinType::RightAnti,
                (50, Inexact(10), Inexact(20), Absent, Absent),
                (10, Inexact(15), Inexact(25), Absent, Absent),
                Some(10),
            ),
            // Anti join without any column statistics.
            (
                JoinType::LeftAnti,
                (10, Absent, Absent, Absent, Absent),
                (50, Absent, Absent, Absent, Absent),
                Some(10),
            ),
            // Anti joins with non-overlapping bounds.
            (
                JoinType::LeftAnti,
                (50, Inexact(10), Inexact(20), Absent, Absent),
                (10, Inexact(30), Inexact(40), Absent, Absent),
                Some(50),
            ),
            (
                JoinType::LeftAnti,
                (50, Inexact(10), Absent, Absent, Absent),
                (10, Absent, Inexact(5), Absent, Absent),
                Some(50),
            ),
            (
                JoinType::LeftAnti,
                (50, Absent, Inexact(20), Absent, Absent),
                (10, Inexact(30), Absent, Absent, Absent),
                Some(50),
            ),
        ];

        let join_on = vec![(
            Arc::new(Column::new("l_col", 0)) as _,
            Arc::new(Column::new("r_col", 0)) as _,
        )];

        for (join_type, outer_info, inner_info, expected) in cases {
            let outer_num_rows = outer_info.0;
            let outer_col_stats = vec![create_column_stats(
                outer_info.1,
                outer_info.2,
                outer_info.3,
                outer_info.4,
            )];

            let inner_num_rows = inner_info.0;
            let inner_col_stats = vec![create_column_stats(
                inner_info.1,
                inner_info.2,
                inner_info.3,
                inner_info.4,
            )];

            // Only the estimated row count is compared.
            let output_cardinality = estimate_join_cardinality(
                &join_type,
                Statistics {
                    num_rows: Inexact(outer_num_rows),
                    total_byte_size: Absent,
                    column_statistics: outer_col_stats,
                },
                Statistics {
                    num_rows: Inexact(inner_num_rows),
                    total_byte_size: Absent,
                    column_statistics: inner_col_stats,
                },
                &join_on,
            )
            .map(|cardinality| cardinality.num_rows);

            assert_eq!(
                output_cardinality, expected,
                "failure for join_type: {join_type}"
            );
        }

        Ok(())
    }
2620
/// Checks how `estimate_join_cardinality` treats a `LeftSemi` join when the
/// row count of one or both inputs is `Absent`:
///
/// * outer absent, inner known  -> no estimate (`None`)
/// * outer known, inner absent  -> estimate equals the outer row count
/// * both absent                -> no estimate (`None`)
#[test]
fn test_semi_join_cardinality_absent_rows() -> Result<()> {
    let dummy_column_stats = vec![create_column_stats(Absent, Absent, Absent, Absent)];
    let join_on = vec![(
        Arc::new(Column::new("l_col", 0)) as _,
        Arc::new(Column::new("r_col", 0)) as _,
    )];

    // Runs the LeftSemi estimation for the given outer/inner row counts; every
    // other statistic is left absent so that only `num_rows` drives the result.
    let estimate_left_semi = |outer_rows, inner_rows| {
        estimate_join_cardinality(
            &JoinType::LeftSemi,
            Statistics {
                num_rows: outer_rows,
                total_byte_size: Absent,
                column_statistics: dummy_column_stats.clone(),
            },
            Statistics {
                num_rows: inner_rows,
                total_byte_size: Absent,
                column_statistics: dummy_column_stats.clone(),
            },
            &join_on,
        )
    };

    // Outer side unknown: nothing can be estimated.
    let absent_outer_estimation = estimate_left_semi(Absent, Exact(10));
    assert!(
        absent_outer_estimation.is_none(),
        "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
    );

    // Inner side unknown: the estimate is expected to equal the outer row count.
    let absent_inner_estimation = estimate_left_semi(Inexact(500), Absent)
        .expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows");
    assert_eq!(
        absent_inner_estimation.num_rows, 500,
        "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"
    );

    // Both sides unknown: nothing can be estimated.
    let absent_both_estimation = estimate_left_semi(Absent, Absent);
    assert!(
        absent_both_estimation.is_none(),
        "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"
    );

    Ok(())
}
2684
/// Verifies `calculate_join_output_ordering` for an inner join with a
/// five-column left input, for two (maintained-order, probe-side)
/// combinations:
///
/// * `[true, false]` with a left probe side expects the left ordering followed
///   by the right ordering;
/// * `[false, true]` with a right probe side expects the right ordering
///   followed by the left ordering.
///
/// In both expected outputs the right-side column indices are offset by the
/// number of left columns (`z`: 2 -> 7, `y`: 1 -> 6).
#[test]
fn test_calculate_join_output_ordering() -> Result<()> {
    // Default-options sort expression over the column `name` at `index`.
    let sort_col = |name: &str, index: usize| {
        PhysicalSortExpr::new_default(Arc::new(Column::new(name, index)))
    };

    let left_ordering =
        LexOrdering::new(vec![sort_col("a", 0), sort_col("c", 2), sort_col("d", 3)]);
    let right_ordering = LexOrdering::new(vec![sort_col("z", 2), sort_col("y", 1)]);

    let join_type = JoinType::Inner;
    let left_columns_len = 5;
    let maintains_input_orders = [[true, false], [false, true]];
    let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];

    let expected = [
        LexOrdering::new(vec![
            sort_col("a", 0),
            sort_col("c", 2),
            sort_col("d", 3),
            sort_col("z", 7),
            sort_col("y", 6),
        ]),
        LexOrdering::new(vec![
            sort_col("z", 7),
            sort_col("y", 6),
            sort_col("a", 0),
            sort_col("c", 2),
            sort_col("d", 3),
        ]),
    ];

    // Walk the three parallel arrays in lock-step, one scenario per iteration.
    let scenarios = maintains_input_orders
        .iter()
        .zip(probe_sides)
        .zip(expected);
    for ((maintains_input_order, probe_side), expected_ordering) in scenarios {
        let actual_ordering = calculate_join_output_ordering(
            left_ordering.as_ref(),
            right_ordering.as_ref(),
            join_type,
            left_columns_len,
            maintains_input_order,
            probe_side,
        )?;
        assert_eq!(actual_ordering, expected_ordering);
    }

    Ok(())
}
2736
2737 fn create_test_batch(num_rows: usize) -> RecordBatch {
2738 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2739 let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
2740 RecordBatch::try_new(schema, vec![data]).unwrap()
2741 }
2742
2743 fn assert_split_batches(
2744 batches: Vec<(RecordBatch, bool)>,
2745 batch_size: usize,
2746 num_rows: usize,
2747 ) {
2748 let mut row_count = 0;
2749 for (batch, last) in batches.into_iter() {
2750 assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size));
2751 let column = batch
2752 .column(0)
2753 .as_any()
2754 .downcast_ref::<Int32Array>()
2755 .unwrap();
2756 for i in 0..batch.num_rows() {
2757 assert_eq!(column.value(i), i as i32 + row_count as i32);
2758 }
2759 row_count += batch.num_rows();
2760 assert_eq!(last, row_count == num_rows);
2761 }
2762 }
2763
2764 #[rstest]
2765 #[test]
2766 fn test_batch_splitter(
2767 #[values(1, 3, 11)] batch_size: usize,
2768 #[values(1, 6, 50)] num_rows: usize,
2769 ) {
2770 let mut splitter = BatchSplitter::new(batch_size);
2771 splitter.set_batch(create_test_batch(num_rows));
2772
2773 let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size));
2774 while let Some(batch) = splitter.next() {
2775 batches.push(batch);
2776 }
2777
2778 assert!(splitter.next().is_none());
2779 assert_split_batches(batches, batch_size, num_rows);
2780 }
2781
/// Verifies the projection returned by `swap_reverting_projection` for a
/// two-column left schema (`a`, `b`) and a one-column right schema (`c`).
///
/// The projection must list the columns in the original left-then-right
/// order (`a`, `b`, `c`), each aliased by its own name and referring to its
/// index in the swapped layout, where the right column comes first:
/// `a` -> 1, `b` -> 2, `c` -> 0.
///
/// This is a plain synchronous test: nothing in it awaits, so no async
/// runtime is needed.
#[test]
fn test_swap_reverting_projection() {
    let left_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);

    let proj = swap_reverting_projection(&left_schema, &right_schema);

    // (alias / column name, column index in the swapped input), in output order.
    let expected = [("a", 1), ("b", 2), ("c", 0)];
    assert_eq!(proj.len(), expected.len());

    for (proj_expr, (name, index)) in proj.iter().zip(expected) {
        assert_eq!(proj_expr.alias, name);
        assert_col_expr(&proj_expr.expr, name, index);
    }
}
2807
2808 fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
2809 let col = expr
2810 .as_any()
2811 .downcast_ref::<Column>()
2812 .expect("Projection items should be Column expression");
2813 assert_eq!(col.name(), name);
2814 assert_eq!(col.index(), index);
2815 }
2816
/// Checks which input's schema metadata `build_join_schema` keeps when both
/// inputs define the same metadata key with different values: a `Left` join
/// is expected to keep the left value and a `Right` join the right value.
#[test]
fn test_join_metadata() -> Result<()> {
    // Single-entry metadata map `{"key": value}`.
    let key_metadata =
        |value: &str| HashMap::from([("key".to_string(), value.to_string())]);

    let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)])
        .with_metadata(key_metadata("left"));
    let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)])
        .with_metadata(key_metadata("right"));

    let scenarios = [(JoinType::Left, "left"), (JoinType::Right, "right")];
    for (join_type, expected_value) in scenarios {
        let (join_schema, _) =
            build_join_schema(&left_schema, &right_schema, &join_type);
        assert_eq!(join_schema.metadata(), &key_metadata(expected_value));
    }

    Ok(())
}
2840}