1use std::cmp::min;
21use std::collections::HashSet;
22use std::fmt::{self, Debug};
23use std::future::Future;
24use std::iter::once;
25use std::ops::Range;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29use crate::joins::SharedBitmapBuilder;
30use crate::metrics::{self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
31use crate::projection::{ProjectionExec, ProjectionExpr};
32use crate::{
33 ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
34};
35pub use super::join_filter::JoinFilter;
37pub use super::join_hash_map::JoinHashMapType;
38pub use crate::joins::{JoinOn, JoinOnRef};
39
40use ahash::RandomState;
41use arrow::array::{
42 builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType,
43 BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions,
44 UInt32Array, UInt32Builder, UInt64Array,
45};
46use arrow::array::{ArrayRef, BooleanArray};
47use arrow::buffer::{BooleanBuffer, NullBuffer};
48use arrow::compute::kernels::cmp::eq;
49use arrow::compute::{self, and, take, FilterBuilder};
50use arrow::datatypes::{
51 ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
52};
53use arrow_ord::cmp::not_distinct;
54use arrow_schema::ArrowError;
55use datafusion_common::cast::as_boolean_array;
56use datafusion_common::hash_utils::create_hashes;
57use datafusion_common::stats::Precision;
58use datafusion_common::{
59 plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult,
60};
61use datafusion_expr::interval_arithmetic::Interval;
62use datafusion_expr::Operator;
63use datafusion_physical_expr::expressions::Column;
64use datafusion_physical_expr::utils::collect_columns;
65use datafusion_physical_expr::{
66 add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr,
67 PhysicalExprRef,
68};
69
70use datafusion_physical_expr_common::datum::compare_op_for_nested;
71use futures::future::{BoxFuture, Shared};
72use futures::{ready, FutureExt};
73use parking_lot::Mutex;
74
75pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
78 let left: HashSet<Column> = left
79 .fields()
80 .iter()
81 .enumerate()
82 .map(|(idx, f)| Column::new(f.name(), idx))
83 .collect();
84 let right: HashSet<Column> = right
85 .fields()
86 .iter()
87 .enumerate()
88 .map(|(idx, f)| Column::new(f.name(), idx))
89 .collect();
90
91 check_join_set_is_valid(&left, &right, on)
92}
93
94fn check_join_set_is_valid(
97 left: &HashSet<Column>,
98 right: &HashSet<Column>,
99 on: &[(PhysicalExprRef, PhysicalExprRef)],
100) -> Result<()> {
101 let on_left = &on
102 .iter()
103 .flat_map(|on| collect_columns(&on.0))
104 .collect::<HashSet<_>>();
105 let left_missing = on_left.difference(left).collect::<HashSet<_>>();
106
107 let on_right = &on
108 .iter()
109 .flat_map(|on| collect_columns(&on.1))
110 .collect::<HashSet<_>>();
111 let right_missing = on_right.difference(right).collect::<HashSet<_>>();
112
113 if !left_missing.is_empty() | !right_missing.is_empty() {
114 return plan_err!(
115 "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}"
116 );
117 };
118
119 Ok(())
120}
121
122pub fn adjust_right_output_partitioning(
124 right_partitioning: &Partitioning,
125 left_columns_len: usize,
126) -> Result<Partitioning> {
127 let result = match right_partitioning {
128 Partitioning::Hash(exprs, size) => {
129 let new_exprs = exprs
130 .iter()
131 .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _))
132 .collect::<Result<_>>()?;
133 Partitioning::Hash(new_exprs, *size)
134 }
135 result => result.clone(),
136 };
137 Ok(result)
138}
139
140pub fn calculate_join_output_ordering(
142 left_ordering: Option<&LexOrdering>,
143 right_ordering: Option<&LexOrdering>,
144 join_type: JoinType,
145 left_columns_len: usize,
146 maintains_input_order: &[bool],
147 probe_side: Option<JoinSide>,
148) -> Result<Option<LexOrdering>> {
149 match maintains_input_order {
150 [true, false] => {
151 if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) {
153 if let Some(right_ordering) = right_ordering.cloned() {
154 let right_offset = add_offset_to_physical_sort_exprs(
155 right_ordering,
156 left_columns_len as _,
157 )?;
158 return if let Some(left_ordering) = left_ordering {
159 let mut result = left_ordering.clone();
160 result.extend(right_offset);
161 Ok(Some(result))
162 } else {
163 Ok(LexOrdering::new(right_offset))
164 };
165 }
166 }
167 Ok(left_ordering.cloned())
168 }
169 [false, true] => {
170 if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) {
172 return if let Some(right_ordering) = right_ordering.cloned() {
173 let mut right_offset = add_offset_to_physical_sort_exprs(
174 right_ordering,
175 left_columns_len as _,
176 )?;
177 if let Some(left_ordering) = left_ordering {
178 right_offset.extend(left_ordering.clone());
179 }
180 Ok(LexOrdering::new(right_offset))
181 } else {
182 Ok(left_ordering.cloned())
183 };
184 }
185 let Some(right_ordering) = right_ordering else {
186 return Ok(None);
187 };
188 match join_type {
189 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
190 add_offset_to_physical_sort_exprs(
191 right_ordering.clone(),
192 left_columns_len as _,
193 )
194 .map(LexOrdering::new)
195 }
196 _ => Ok(Some(right_ordering.clone())),
197 }
198 }
199 [false, false] => Ok(None),
201 [true, true] => unreachable!("Cannot maintain ordering of both sides"),
202 _ => unreachable!("Join operators can not have more than two children"),
203 }
204}
205
206#[derive(Debug, Clone, PartialEq)]
208pub struct ColumnIndex {
209 pub index: usize,
211 pub side: JoinSide,
213}
214
215fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
219 let force_nullable = match join_type {
220 JoinType::Inner => false,
221 JoinType::Left => !is_left, JoinType::Right => is_left, JoinType::Full => true, JoinType::LeftSemi => false, JoinType::RightSemi => false, JoinType::LeftAnti => false, JoinType::RightAnti => false, JoinType::LeftMark => false,
229 JoinType::RightMark => false,
230 };
231
232 if force_nullable {
233 old_field.clone().with_nullable(true)
234 } else {
235 old_field.clone()
236 }
237}
238
239pub fn build_join_schema(
242 left: &Schema,
243 right: &Schema,
244 join_type: &JoinType,
245) -> (Schema, Vec<ColumnIndex>) {
246 let left_fields = || {
247 left.fields()
248 .iter()
249 .map(|f| output_join_field(f, join_type, true))
250 .enumerate()
251 .map(|(index, f)| {
252 (
253 f,
254 ColumnIndex {
255 index,
256 side: JoinSide::Left,
257 },
258 )
259 })
260 };
261
262 let right_fields = || {
263 right
264 .fields()
265 .iter()
266 .map(|f| output_join_field(f, join_type, false))
267 .enumerate()
268 .map(|(index, f)| {
269 (
270 f,
271 ColumnIndex {
272 index,
273 side: JoinSide::Right,
274 },
275 )
276 })
277 };
278
279 let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
280 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
281 left_fields().chain(right_fields()).unzip()
283 }
284 JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
285 JoinType::LeftMark => {
286 let right_field = once((
287 Field::new("mark", arrow::datatypes::DataType::Boolean, false),
288 ColumnIndex {
289 index: 0,
290 side: JoinSide::None,
291 },
292 ));
293 left_fields().chain(right_field).unzip()
294 }
295 JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
296 JoinType::RightMark => {
297 let left_field = once((
298 Field::new("mark", arrow_schema::DataType::Boolean, false),
299 ColumnIndex {
300 index: 0,
301 side: JoinSide::None,
302 },
303 ));
304 right_fields().chain(left_field).unzip()
305 }
306 };
307
308 let (schema1, schema2) = match join_type {
309 JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right),
310 _ => (right, left),
311 };
312
313 let metadata = schema1
314 .metadata()
315 .clone()
316 .into_iter()
317 .chain(schema2.metadata().clone())
318 .collect();
319
320 (fields.finish().with_metadata(metadata), column_indices)
321}
322
323pub(crate) struct OnceAsync<T> {
337 fut: Mutex<Option<SharedResult<OnceFut<T>>>>,
338}
339
340impl<T> Default for OnceAsync<T> {
341 fn default() -> Self {
342 Self {
343 fut: Mutex::new(None),
344 }
345 }
346}
347
348impl<T> Debug for OnceAsync<T> {
349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 write!(f, "OnceAsync")
351 }
352}
353
354impl<T: 'static> OnceAsync<T> {
355 pub(crate) fn try_once<F, Fut>(&self, f: F) -> Result<OnceFut<T>>
363 where
364 F: FnOnce() -> Result<Fut>,
365 Fut: Future<Output = Result<T>> + Send + 'static,
366 {
367 self.fut
368 .lock()
369 .get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new))
370 .clone()
371 .map_err(DataFusionError::Shared)
372 }
373}
374
375type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;
377
378pub(crate) struct OnceFut<T> {
382 state: OnceFutState<T>,
383}
384
385impl<T> Clone for OnceFut<T> {
386 fn clone(&self) -> Self {
387 Self {
388 state: self.state.clone(),
389 }
390 }
391}
392
393#[derive(Clone, Debug, Default)]
396struct PartialJoinStatistics {
397 pub num_rows: usize,
398 pub column_statistics: Vec<ColumnStatistics>,
399}
400
401pub(crate) fn estimate_join_statistics(
403 left_stats: Statistics,
404 right_stats: Statistics,
405 on: JoinOn,
406 join_type: &JoinType,
407 schema: &Schema,
408) -> Result<Statistics> {
409 let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on);
410 let (num_rows, column_statistics) = match join_stats {
411 Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics),
412 None => (Precision::Absent, Statistics::unknown_column(schema)),
413 };
414 Ok(Statistics {
415 num_rows,
416 total_byte_size: Precision::Absent,
417 column_statistics,
418 })
419}
420
421fn estimate_join_cardinality(
423 join_type: &JoinType,
424 left_stats: Statistics,
425 right_stats: Statistics,
426 on: &JoinOn,
427) -> Option<PartialJoinStatistics> {
428 let (left_col_stats, right_col_stats) = on
429 .iter()
430 .map(|(left, right)| {
431 match (
432 left.as_any().downcast_ref::<Column>(),
433 right.as_any().downcast_ref::<Column>(),
434 ) {
435 (Some(left), Some(right)) => (
436 left_stats.column_statistics[left.index()].clone(),
437 right_stats.column_statistics[right.index()].clone(),
438 ),
439 _ => (
440 ColumnStatistics::new_unknown(),
441 ColumnStatistics::new_unknown(),
442 ),
443 }
444 })
445 .unzip::<_, _, Vec<_>, Vec<_>>();
446
447 match join_type {
448 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
449 let ij_cardinality = estimate_inner_join_cardinality(
450 Statistics {
451 num_rows: left_stats.num_rows,
452 total_byte_size: Precision::Absent,
453 column_statistics: left_col_stats,
454 },
455 Statistics {
456 num_rows: right_stats.num_rows,
457 total_byte_size: Precision::Absent,
458 column_statistics: right_col_stats,
459 },
460 )?;
461
462 let cardinality = match join_type {
467 JoinType::Inner => ij_cardinality,
468 JoinType::Left => ij_cardinality.max(&left_stats.num_rows),
469 JoinType::Right => ij_cardinality.max(&right_stats.num_rows),
470 JoinType::Full => ij_cardinality
471 .max(&left_stats.num_rows)
472 .add(&ij_cardinality.max(&right_stats.num_rows))
473 .sub(&ij_cardinality),
474 _ => unreachable!(),
475 };
476
477 Some(PartialJoinStatistics {
478 num_rows: *cardinality.get_value()?,
479 column_statistics: left_stats
484 .column_statistics
485 .into_iter()
486 .chain(right_stats.column_statistics)
487 .collect(),
488 })
489 }
490
491 JoinType::LeftSemi | JoinType::RightSemi => {
495 let (outer_stats, inner_stats) = match join_type {
496 JoinType::LeftSemi => (left_stats, right_stats),
497 _ => (right_stats, left_stats),
498 };
499 let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
500 Some(estimation) => *estimation.get_value()?,
501 None => *outer_stats.num_rows.get_value()?,
502 };
503
504 Some(PartialJoinStatistics {
505 num_rows: cardinality,
506 column_statistics: outer_stats.column_statistics,
507 })
508 }
509
510 JoinType::LeftAnti | JoinType::RightAnti => {
513 let outer_stats = match join_type {
514 JoinType::LeftAnti => left_stats,
515 _ => right_stats,
516 };
517
518 Some(PartialJoinStatistics {
519 num_rows: *outer_stats.num_rows.get_value()?,
520 column_statistics: outer_stats.column_statistics,
521 })
522 }
523
524 JoinType::LeftMark => {
525 let num_rows = *left_stats.num_rows.get_value()?;
526 let mut column_statistics = left_stats.column_statistics;
527 column_statistics.push(ColumnStatistics::new_unknown());
528 Some(PartialJoinStatistics {
529 num_rows,
530 column_statistics,
531 })
532 }
533 JoinType::RightMark => {
534 let num_rows = *right_stats.num_rows.get_value()?;
535 let mut column_statistics = right_stats.column_statistics;
536 column_statistics.push(ColumnStatistics::new_unknown());
537 Some(PartialJoinStatistics {
538 num_rows,
539 column_statistics,
540 })
541 }
542 }
543}
544
545fn estimate_inner_join_cardinality(
550 left_stats: Statistics,
551 right_stats: Statistics,
552) -> Option<Precision<usize>> {
553 if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) {
555 return Some(estimation);
556 };
557
558 let mut join_selectivity = Precision::Absent;
561 for (left_stat, right_stat) in left_stats
562 .column_statistics
563 .iter()
564 .zip(right_stats.column_statistics.iter())
565 {
566 if left_stat.min_value.get_value().is_none()
568 || left_stat.max_value.get_value().is_none()
569 || right_stat.min_value.get_value().is_none()
570 || right_stat.max_value.get_value().is_none()
571 {
572 return None;
573 }
574
575 let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat);
576 let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat);
577 let max_distinct = left_max_distinct.max(&right_max_distinct);
578 if max_distinct.get_value().is_some() {
579 join_selectivity = max_distinct;
583 }
584 }
585
586 let left_num_rows = left_stats.num_rows.get_value()?;
590 let right_num_rows = right_stats.num_rows.get_value()?;
591 match join_selectivity {
592 Precision::Exact(value) if value > 0 => {
593 Some(Precision::Exact((left_num_rows * right_num_rows) / value))
594 }
595 Precision::Inexact(value) if value > 0 => {
596 Some(Precision::Inexact((left_num_rows * right_num_rows) / value))
597 }
598 _ => None,
603 }
604}
605
606fn estimate_disjoint_inputs(
609 left_stats: &Statistics,
610 right_stats: &Statistics,
611) -> Option<Precision<usize>> {
612 for (left_stat, right_stat) in left_stats
613 .column_statistics
614 .iter()
615 .zip(right_stats.column_statistics.iter())
616 {
617 let left_min_val = left_stat.min_value.get_value();
621 let right_max_val = right_stat.max_value.get_value();
622 if left_min_val.is_some()
623 && right_max_val.is_some()
624 && left_min_val > right_max_val
625 {
626 return Some(
627 if left_stat.min_value.is_exact().unwrap_or(false)
628 && right_stat.max_value.is_exact().unwrap_or(false)
629 {
630 Precision::Exact(0)
631 } else {
632 Precision::Inexact(0)
633 },
634 );
635 }
636
637 let left_max_val = left_stat.max_value.get_value();
638 let right_min_val = right_stat.min_value.get_value();
639 if left_max_val.is_some()
640 && right_min_val.is_some()
641 && left_max_val < right_min_val
642 {
643 return Some(
644 if left_stat.max_value.is_exact().unwrap_or(false)
645 && right_stat.min_value.is_exact().unwrap_or(false)
646 {
647 Precision::Exact(0)
648 } else {
649 Precision::Inexact(0)
650 },
651 );
652 }
653 }
654
655 None
656}
657
658fn max_distinct_count(
663 num_rows: &Precision<usize>,
664 stats: &ColumnStatistics,
665) -> Precision<usize> {
666 match &stats.distinct_count {
667 &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc,
668 _ => {
669 let result = match num_rows {
672 Precision::Absent => Precision::Absent,
673 Precision::Inexact(count) => {
674 match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
677 None => Precision::Inexact(0),
678 Some(non_null_count) => Precision::Inexact(non_null_count),
679 }
680 }
681 Precision::Exact(count) => {
682 let count = count - stats.null_count.get_value().unwrap_or(&0);
683 if stats.null_count.is_exact().unwrap_or(false) {
684 Precision::Exact(count)
685 } else {
686 Precision::Inexact(count)
687 }
688 }
689 };
690 if let (Some(min), Some(max)) =
692 (stats.min_value.get_value(), stats.max_value.get_value())
693 {
694 if let Some(range_dc) = Interval::try_new(min.clone(), max.clone())
695 .ok()
696 .and_then(|e| e.cardinality())
697 {
698 let range_dc = range_dc as usize;
699 return if matches!(result, Precision::Absent)
701 || &range_dc < result.get_value().unwrap()
702 {
703 if stats.min_value.is_exact().unwrap()
704 && stats.max_value.is_exact().unwrap()
705 {
706 Precision::Exact(range_dc)
707 } else {
708 Precision::Inexact(range_dc)
709 }
710 } else {
711 result
712 };
713 }
714 }
715
716 result
717 }
718 }
719}
720
721enum OnceFutState<T> {
722 Pending(OnceFutPending<T>),
723 Ready(SharedResult<Arc<T>>),
724}
725
726impl<T> Clone for OnceFutState<T> {
727 fn clone(&self) -> Self {
728 match self {
729 Self::Pending(p) => Self::Pending(p.clone()),
730 Self::Ready(r) => Self::Ready(r.clone()),
731 }
732 }
733}
734
735impl<T: 'static> OnceFut<T> {
736 pub(crate) fn new<Fut>(fut: Fut) -> Self
738 where
739 Fut: Future<Output = Result<T>> + Send + 'static,
740 {
741 Self {
742 state: OnceFutState::Pending(
743 fut.map(|res| res.map(Arc::new).map_err(Arc::new))
744 .boxed()
745 .shared(),
746 ),
747 }
748 }
749
750 pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
752 if let OnceFutState::Pending(fut) = &mut self.state {
753 let r = ready!(fut.poll_unpin(cx));
754 self.state = OnceFutState::Ready(r);
755 }
756
757 match &self.state {
759 OnceFutState::Pending(_) => unreachable!(),
760 OnceFutState::Ready(r) => Poll::Ready(
761 r.as_ref()
762 .map(|r| r.as_ref())
763 .map_err(DataFusionError::from),
764 ),
765 }
766 }
767
768 pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
770 if let OnceFutState::Pending(fut) = &mut self.state {
771 let r = ready!(fut.poll_unpin(cx));
772 self.state = OnceFutState::Ready(r);
773 }
774
775 match &self.state {
776 OnceFutState::Pending(_) => unreachable!(),
777 OnceFutState::Ready(r) => {
778 Poll::Ready(r.clone().map_err(DataFusionError::Shared))
779 }
780 }
781 }
782}
783
784pub(crate) fn need_produce_right_in_final(join_type: JoinType) -> bool {
791 matches!(
792 join_type,
793 JoinType::Full
794 | JoinType::Right
795 | JoinType::RightAnti
796 | JoinType::RightMark
797 | JoinType::RightSemi
798 )
799}
800
801pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
807 matches!(
808 join_type,
809 JoinType::Left
810 | JoinType::LeftAnti
811 | JoinType::LeftSemi
812 | JoinType::LeftMark
813 | JoinType::Full
814 )
815}
816
817pub(crate) fn get_final_indices_from_shared_bitmap(
818 shared_bitmap: &SharedBitmapBuilder,
819 join_type: JoinType,
820) -> (UInt64Array, UInt32Array) {
821 let bitmap = shared_bitmap.lock();
822 get_final_indices_from_bit_map(&bitmap, join_type)
823}
824
825pub(crate) fn get_final_indices_from_bit_map(
835 left_bit_map: &BooleanBufferBuilder,
836 join_type: JoinType,
837) -> (UInt64Array, UInt32Array) {
838 let left_size = left_bit_map.len();
839 if join_type == JoinType::LeftMark {
840 let left_indices = (0..left_size as u64).collect::<UInt64Array>();
841 let right_indices = (0..left_size)
842 .map(|idx| left_bit_map.get_bit(idx).then_some(0))
843 .collect::<UInt32Array>();
844 return (left_indices, right_indices);
845 }
846 let left_indices = if join_type == JoinType::LeftSemi {
847 (0..left_size)
848 .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
849 .collect::<UInt64Array>()
850 } else {
851 (0..left_size)
854 .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
855 .collect::<UInt64Array>()
856 };
857 let mut builder = UInt32Builder::with_capacity(left_indices.len());
860 builder.append_nulls(left_indices.len());
861 let right_indices = builder.finish();
862 (left_indices, right_indices)
863}
864
865pub(crate) fn apply_join_filter_to_indices(
866 build_input_buffer: &RecordBatch,
867 probe_batch: &RecordBatch,
868 build_indices: UInt64Array,
869 probe_indices: UInt32Array,
870 filter: &JoinFilter,
871 build_side: JoinSide,
872 max_intermediate_size: Option<usize>,
873) -> Result<(UInt64Array, UInt32Array)> {
874 if build_indices.is_empty() && probe_indices.is_empty() {
875 return Ok((build_indices, probe_indices));
876 };
877
878 let filter_result = if let Some(max_size) = max_intermediate_size {
879 let mut filter_results =
880 Vec::with_capacity(build_indices.len().div_ceil(max_size));
881
882 for i in (0..build_indices.len()).step_by(max_size) {
883 let end = min(build_indices.len(), i + max_size);
884 let len = end - i;
885 let intermediate_batch = build_batch_from_indices(
886 filter.schema(),
887 build_input_buffer,
888 probe_batch,
889 &build_indices.slice(i, len),
890 &probe_indices.slice(i, len),
891 filter.column_indices(),
892 build_side,
893 )?;
894 let filter_result = filter
895 .expression()
896 .evaluate(&intermediate_batch)?
897 .into_array(intermediate_batch.num_rows())?;
898 filter_results.push(filter_result);
899 }
900
901 let filter_refs: Vec<&dyn Array> =
902 filter_results.iter().map(|a| a.as_ref()).collect();
903
904 compute::concat(&filter_refs)?
905 } else {
906 let intermediate_batch = build_batch_from_indices(
907 filter.schema(),
908 build_input_buffer,
909 probe_batch,
910 &build_indices,
911 &probe_indices,
912 filter.column_indices(),
913 build_side,
914 )?;
915
916 filter
917 .expression()
918 .evaluate(&intermediate_batch)?
919 .into_array(intermediate_batch.num_rows())?
920 };
921
922 let mask = as_boolean_array(&filter_result)?;
923
924 let left_filtered = compute::filter(&build_indices, mask)?;
925 let right_filtered = compute::filter(&probe_indices, mask)?;
926 Ok((
927 downcast_array(left_filtered.as_ref()),
928 downcast_array(right_filtered.as_ref()),
929 ))
930}
931
932pub(crate) fn build_batch_from_indices(
935 schema: &Schema,
936 build_input_buffer: &RecordBatch,
937 probe_batch: &RecordBatch,
938 build_indices: &UInt64Array,
939 probe_indices: &UInt32Array,
940 column_indices: &[ColumnIndex],
941 build_side: JoinSide,
942) -> Result<RecordBatch> {
943 if schema.fields().is_empty() {
944 let options = RecordBatchOptions::new()
945 .with_match_field_names(true)
946 .with_row_count(Some(build_indices.len()));
947
948 return Ok(RecordBatch::try_new_with_options(
949 Arc::new(schema.clone()),
950 vec![],
951 &options,
952 )?);
953 }
954
955 let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
959
960 for column_index in column_indices {
961 let array = if column_index.side == JoinSide::None {
962 Arc::new(compute::is_not_null(probe_indices)?)
964 } else if column_index.side == build_side {
965 let array = build_input_buffer.column(column_index.index);
966 if array.is_empty() || build_indices.null_count() == build_indices.len() {
967 assert_eq!(build_indices.null_count(), build_indices.len());
971 new_null_array(array.data_type(), build_indices.len())
972 } else {
973 take(array.as_ref(), build_indices, None)?
974 }
975 } else {
976 let array = probe_batch.column(column_index.index);
977 if array.is_empty() || probe_indices.null_count() == probe_indices.len() {
978 assert_eq!(probe_indices.null_count(), probe_indices.len());
979 new_null_array(array.data_type(), probe_indices.len())
980 } else {
981 take(array.as_ref(), probe_indices, None)?
982 }
983 };
984
985 columns.push(array);
986 }
987 Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
988}
989
990pub(crate) fn build_batch_empty_build_side(
993 schema: &Schema,
994 build_batch: &RecordBatch,
995 probe_batch: &RecordBatch,
996 column_indices: &[ColumnIndex],
997 join_type: JoinType,
998) -> Result<RecordBatch> {
999 match join_type {
1000 JoinType::Inner
1003 | JoinType::Left
1004 | JoinType::LeftSemi
1005 | JoinType::RightSemi
1006 | JoinType::LeftAnti
1007 | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))),
1008
1009 JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => {
1011 let num_rows = probe_batch.num_rows();
1012 let mut columns: Vec<Arc<dyn Array>> =
1013 Vec::with_capacity(schema.fields().len());
1014
1015 for column_index in column_indices {
1016 let array = match column_index.side {
1017 JoinSide::Left => new_null_array(
1019 build_batch.column(column_index.index).data_type(),
1020 num_rows,
1021 ),
1022 JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)),
1024 JoinSide::None => Arc::new(BooleanArray::new(
1026 BooleanBuffer::new_unset(num_rows),
1027 None,
1028 )),
1029 };
1030
1031 columns.push(array);
1032 }
1033
1034 Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
1035 }
1036 }
1037}
1038
1039pub(crate) fn adjust_indices_by_join_type(
1042 left_indices: UInt64Array,
1043 right_indices: UInt32Array,
1044 adjust_range: Range<usize>,
1045 join_type: JoinType,
1046 preserve_order_for_right: bool,
1047) -> Result<(UInt64Array, UInt32Array)> {
1048 match join_type {
1049 JoinType::Inner => {
1050 Ok((left_indices, right_indices))
1052 }
1053 JoinType::Left => {
1054 Ok((left_indices, right_indices))
1056 }
1058 JoinType::Right => {
1059 append_right_indices(
1061 left_indices,
1062 right_indices,
1063 adjust_range,
1064 preserve_order_for_right,
1065 )
1066 }
1067 JoinType::Full => {
1068 append_right_indices(left_indices, right_indices, adjust_range, false)
1069 }
1070 JoinType::RightSemi => {
1071 let right_indices = get_semi_indices(adjust_range, &right_indices);
1073 Ok((left_indices, right_indices))
1075 }
1076 JoinType::RightAnti => {
1077 let right_indices = get_anti_indices(adjust_range, &right_indices);
1080 Ok((left_indices, right_indices))
1082 }
1083 JoinType::RightMark => {
1084 let right_indices = get_mark_indices(&adjust_range, &right_indices);
1085 let left_indices_vec: Vec<u64> = adjust_range.map(|i| i as u64).collect();
1086 let left_indices = UInt64Array::from(left_indices_vec);
1087 Ok((left_indices, right_indices))
1088 }
1089 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
1090 Ok((
1093 UInt64Array::from_iter_values(vec![]),
1094 UInt32Array::from_iter_values(vec![]),
1095 ))
1096 }
1097 }
1098}
1099
1100pub(crate) fn append_right_indices(
1116 left_indices: UInt64Array,
1117 right_indices: UInt32Array,
1118 adjust_range: Range<usize>,
1119 preserve_order_for_right: bool,
1120) -> Result<(UInt64Array, UInt32Array)> {
1121 if preserve_order_for_right {
1122 Ok(append_probe_indices_in_order(
1123 left_indices,
1124 right_indices,
1125 adjust_range,
1126 ))
1127 } else {
1128 let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices);
1129
1130 if right_unmatched_indices.is_empty() {
1131 Ok((left_indices, right_indices))
1132 } else {
1133 let mut new_left_indices_builder =
1139 left_indices.into_builder().unwrap_or_else(|left_indices| {
1140 let mut builder = UInt64Builder::with_capacity(
1141 left_indices.len() + right_unmatched_indices.len(),
1142 );
1143 debug_assert_eq!(
1144 left_indices.null_count(),
1145 0,
1146 "expected left indices to have no nulls"
1147 );
1148 builder.append_slice(left_indices.values());
1149 builder
1150 });
1151 new_left_indices_builder.append_nulls(right_unmatched_indices.len());
1152 let new_left_indices = UInt64Array::from(new_left_indices_builder.finish());
1153
1154 let mut new_right_indices_builder = right_indices
1156 .into_builder()
1157 .unwrap_or_else(|right_indices| {
1158 let mut builder = UInt32Builder::with_capacity(
1159 right_indices.len() + right_unmatched_indices.len(),
1160 );
1161 debug_assert_eq!(
1162 right_indices.null_count(),
1163 0,
1164 "expected right indices to have no nulls"
1165 );
1166 builder.append_slice(right_indices.values());
1167 builder
1168 });
1169 debug_assert_eq!(
1170 right_unmatched_indices.null_count(),
1171 0,
1172 "expected right unmatched indices to have no nulls"
1173 );
1174 new_right_indices_builder.append_slice(right_unmatched_indices.values());
1175 let new_right_indices = UInt32Array::from(new_right_indices_builder.finish());
1176
1177 Ok((new_left_indices, new_right_indices))
1178 }
1179 }
1180}
1181
1182pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
1184 range: Range<usize>,
1185 input_indices: &PrimitiveArray<T>,
1186) -> PrimitiveArray<T>
1187where
1188 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1189{
1190 let bitmap = build_range_bitmap(&range, input_indices);
1191 let offset = range.start;
1192
1193 (range)
1195 .filter_map(|idx| {
1196 (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1197 })
1198 .collect()
1199}
1200
1201pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
1203 range: Range<usize>,
1204 input_indices: &PrimitiveArray<T>,
1205) -> PrimitiveArray<T>
1206where
1207 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1208{
1209 let bitmap = build_range_bitmap(&range, input_indices);
1210 let offset = range.start;
1211 (range)
1213 .filter_map(|idx| {
1214 (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1215 })
1216 .collect()
1217}
1218
1219pub(crate) fn get_mark_indices<T: ArrowPrimitiveType>(
1220 range: &Range<usize>,
1221 input_indices: &PrimitiveArray<T>,
1222) -> PrimitiveArray<UInt32Type>
1223where
1224 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1225{
1226 let mut bitmap = build_range_bitmap(range, input_indices);
1227 PrimitiveArray::new(
1228 vec![0; range.len()].into(),
1229 Some(NullBuffer::new(bitmap.finish())),
1230 )
1231}
1232
1233fn build_range_bitmap<T: ArrowPrimitiveType>(
1234 range: &Range<usize>,
1235 input: &PrimitiveArray<T>,
1236) -> BooleanBufferBuilder {
1237 let mut builder = BooleanBufferBuilder::new(range.len());
1238 builder.append_n(range.len(), false);
1239
1240 input.iter().flatten().for_each(|v| {
1241 let idx = v.as_usize();
1242 if range.contains(&idx) {
1243 builder.set_bit(idx - range.start, true);
1244 }
1245 });
1246
1247 builder
1248}
1249
1250fn append_probe_indices_in_order(
1268 build_indices: PrimitiveArray<UInt64Type>,
1269 probe_indices: PrimitiveArray<UInt32Type>,
1270 range: Range<usize>,
1271) -> (PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>) {
1272 let mut new_build_indices = UInt64Builder::new();
1274 let mut new_probe_indices = UInt32Builder::new();
1275 let mut prev_index = range.start as u32;
1277 debug_assert!(build_indices.len() == probe_indices.len());
1279 for (build_index, probe_index) in build_indices
1280 .values()
1281 .into_iter()
1282 .zip(probe_indices.values().into_iter())
1283 {
1284 for value in prev_index..*probe_index {
1286 new_probe_indices.append_value(value);
1287 new_build_indices.append_null();
1288 }
1289 new_probe_indices.append_value(*probe_index);
1291 new_build_indices.append_value(*build_index);
1292 prev_index = probe_index + 1;
1294 }
1295 for value in prev_index..range.end as u32 {
1297 new_probe_indices.append_value(value);
1298 new_build_indices.append_null();
1299 }
1300 (new_build_indices.finish(), new_probe_indices.finish())
1302}
1303
1304#[derive(Clone, Debug)]
1306pub(crate) struct BuildProbeJoinMetrics {
1307 pub(crate) baseline: BaselineMetrics,
1308 pub(crate) build_time: metrics::Time,
1310 pub(crate) build_input_batches: metrics::Count,
1312 pub(crate) build_input_rows: metrics::Count,
1314 pub(crate) build_mem_used: metrics::Gauge,
1316 pub(crate) join_time: metrics::Time,
1318 pub(crate) input_batches: metrics::Count,
1320 pub(crate) input_rows: metrics::Count,
1322 pub(crate) output_batches: metrics::Count,
1324}
1325
1326impl Drop for BuildProbeJoinMetrics {
1339 fn drop(&mut self) {
1340 self.baseline.elapsed_compute().add(&self.build_time);
1341 self.baseline.elapsed_compute().add(&self.join_time);
1342 }
1343}
1344
1345impl BuildProbeJoinMetrics {
1346 pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
1347 let baseline = BaselineMetrics::new(metrics, partition);
1348
1349 let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
1350
1351 let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition);
1352
1353 let build_input_batches =
1354 MetricBuilder::new(metrics).counter("build_input_batches", partition);
1355
1356 let build_input_rows =
1357 MetricBuilder::new(metrics).counter("build_input_rows", partition);
1358
1359 let build_mem_used =
1360 MetricBuilder::new(metrics).gauge("build_mem_used", partition);
1361
1362 let input_batches =
1363 MetricBuilder::new(metrics).counter("input_batches", partition);
1364
1365 let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
1366
1367 let output_batches =
1368 MetricBuilder::new(metrics).counter("output_batches", partition);
1369
1370 Self {
1371 build_time,
1372 build_input_batches,
1373 build_input_rows,
1374 build_mem_used,
1375 join_time,
1376 input_batches,
1377 input_rows,
1378 output_batches,
1379 baseline,
1380 }
1381 }
1382}
1383
/// Translates the outcome of one state-machine step into a `Poll` value.
///
/// Must be expanded inside a loop:
/// * `Err(e)` becomes `Poll::Ready(Some(Err(e)))`;
/// * `Ok(StatefulStreamResult::Ready(v))` becomes
///   `Poll::Ready(Ok(v).transpose())`, so `v` is expected to be an `Option`;
/// * `Ok(StatefulStreamResult::Continue)` `continue`s the enclosing loop.
#[macro_export]
macro_rules! handle_state {
    ($match_case:expr) => {
        match $match_case {
            Err(e) => Poll::Ready(Some(Err(e))),
            Ok(StatefulStreamResult::Ready(result)) => {
                Poll::Ready(Ok(result).transpose())
            }
            Ok(StatefulStreamResult::Continue) => continue,
        }
    };
}
1414
/// Outcome of a single step of a stateful stream, as consumed by the
/// `handle_state!` macro.
pub enum StatefulStreamResult<T> {
    /// A value is available; `handle_state!` wraps it in `Poll::Ready`.
    Ready(T),
    /// No value yet; `handle_state!` `continue`s the enclosing loop so the
    /// next step runs.
    Continue,
}
1430
1431pub(crate) fn symmetric_join_output_partitioning(
1432 left: &Arc<dyn ExecutionPlan>,
1433 right: &Arc<dyn ExecutionPlan>,
1434 join_type: &JoinType,
1435) -> Result<Partitioning> {
1436 let left_columns_len = left.schema().fields.len();
1437 let left_partitioning = left.output_partitioning();
1438 let right_partitioning = right.output_partitioning();
1439 let result = match join_type {
1440 JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
1441 left_partitioning.clone()
1442 }
1443 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1444 right_partitioning.clone()
1445 }
1446 JoinType::Inner | JoinType::Right => {
1447 adjust_right_output_partitioning(right_partitioning, left_columns_len)?
1448 }
1449 JoinType::Full => {
1450 Partitioning::UnknownPartitioning(right_partitioning.partition_count())
1452 }
1453 };
1454 Ok(result)
1455}
1456
1457pub(crate) fn asymmetric_join_output_partitioning(
1458 left: &Arc<dyn ExecutionPlan>,
1459 right: &Arc<dyn ExecutionPlan>,
1460 join_type: &JoinType,
1461) -> Result<Partitioning> {
1462 let result = match join_type {
1463 JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
1464 right.output_partitioning(),
1465 left.schema().fields().len(),
1466 )?,
1467 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1468 right.output_partitioning().clone()
1469 }
1470 JoinType::Left
1471 | JoinType::LeftSemi
1472 | JoinType::LeftAnti
1473 | JoinType::Full
1474 | JoinType::LeftMark => Partitioning::UnknownPartitioning(
1475 right.output_partitioning().partition_count(),
1476 ),
1477 };
1478 Ok(result)
1479}
1480
1481pub(crate) trait BatchTransformer: Debug + Clone {
1486 fn set_batch(&mut self, batch: RecordBatch);
1488
1489 fn next(&mut self) -> Option<(RecordBatch, bool)>;
1493}
1494
1495#[derive(Debug, Clone)]
1496pub(crate) struct NoopBatchTransformer {
1498 batch: Option<RecordBatch>,
1500}
1501
1502impl NoopBatchTransformer {
1503 pub fn new() -> Self {
1504 Self { batch: None }
1505 }
1506}
1507
1508impl BatchTransformer for NoopBatchTransformer {
1509 fn set_batch(&mut self, batch: RecordBatch) {
1510 self.batch = Some(batch);
1511 }
1512
1513 fn next(&mut self) -> Option<(RecordBatch, bool)> {
1514 self.batch.take().map(|batch| (batch, true))
1515 }
1516}
1517
1518#[derive(Debug, Clone)]
1519pub(crate) struct BatchSplitter {
1521 batch: Option<RecordBatch>,
1523 batch_size: usize,
1525 row_index: usize,
1527}
1528
1529impl BatchSplitter {
1530 pub(crate) fn new(batch_size: usize) -> Self {
1532 Self {
1533 batch: None,
1534 batch_size,
1535 row_index: 0,
1536 }
1537 }
1538}
1539
1540impl BatchTransformer for BatchSplitter {
1541 fn set_batch(&mut self, batch: RecordBatch) {
1542 self.batch = Some(batch);
1543 self.row_index = 0;
1544 }
1545
1546 fn next(&mut self) -> Option<(RecordBatch, bool)> {
1547 let Some(batch) = &self.batch else {
1548 return None;
1549 };
1550
1551 let remaining_rows = batch.num_rows() - self.row_index;
1552 let rows_to_slice = remaining_rows.min(self.batch_size);
1553 let sliced_batch = batch.slice(self.row_index, rows_to_slice);
1554 self.row_index += rows_to_slice;
1555
1556 let mut last = false;
1557 if self.row_index >= batch.num_rows() {
1558 self.batch = None;
1559 last = true;
1560 }
1561
1562 Some((sliced_batch, last))
1563 }
1564}
1565
1566pub fn reorder_output_after_swap(
1573 plan: Arc<dyn ExecutionPlan>,
1574 left_schema: &Schema,
1575 right_schema: &Schema,
1576) -> Result<Arc<dyn ExecutionPlan>> {
1577 let proj = ProjectionExec::try_new(
1578 swap_reverting_projection(left_schema, right_schema),
1579 plan,
1580 )?;
1581 Ok(Arc::new(proj))
1582}
1583
1584fn swap_reverting_projection(
1590 left_schema: &Schema,
1591 right_schema: &Schema,
1592) -> Vec<ProjectionExpr> {
1593 let right_cols =
1594 right_schema
1595 .fields()
1596 .iter()
1597 .enumerate()
1598 .map(|(i, f)| ProjectionExpr {
1599 expr: Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
1600 alias: f.name().to_owned(),
1601 });
1602 let right_len = right_cols.len();
1603 let left_cols =
1604 left_schema
1605 .fields()
1606 .iter()
1607 .enumerate()
1608 .map(|(i, f)| ProjectionExpr {
1609 expr: Arc::new(Column::new(f.name(), right_len + i))
1610 as Arc<dyn PhysicalExpr>,
1611 alias: f.name().to_owned(),
1612 });
1613
1614 left_cols.chain(right_cols).collect()
1615}
1616
1617pub fn swap_join_projection(
1619 left_schema_len: usize,
1620 right_schema_len: usize,
1621 projection: Option<&Vec<usize>>,
1622 join_type: &JoinType,
1623) -> Option<Vec<usize>> {
1624 match join_type {
1625 JoinType::LeftAnti
1628 | JoinType::LeftSemi
1629 | JoinType::RightAnti
1630 | JoinType::RightSemi => projection.cloned(),
1631
1632 _ => projection.map(|p| {
1633 p.iter()
1634 .map(|i| {
1635 if *i < left_schema_len {
1640 *i + right_schema_len
1641 } else {
1642 *i - left_schema_len
1643 }
1644 })
1645 .collect()
1646 }),
1647 }
1648}
1649
1650#[allow(clippy::too_many_arguments)]
1657pub fn update_hash(
1658 on: &[PhysicalExprRef],
1659 batch: &RecordBatch,
1660 hash_map: &mut dyn JoinHashMapType,
1661 offset: usize,
1662 random_state: &RandomState,
1663 hashes_buffer: &mut Vec<u64>,
1664 deleted_offset: usize,
1665 fifo_hashmap: bool,
1666) -> Result<()> {
1667 let keys_values = on
1669 .iter()
1670 .map(|c| c.evaluate(batch)?.into_array(batch.num_rows()))
1671 .collect::<Result<Vec<_>>>()?;
1672
1673 let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1675
1676 hash_map.extend_zero(batch.num_rows());
1678
1679 let hash_values_iter = hash_values
1681 .iter()
1682 .enumerate()
1683 .map(|(i, val)| (i + offset, val));
1684
1685 if fifo_hashmap {
1686 hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset);
1687 } else {
1688 hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset);
1689 }
1690
1691 Ok(())
1692}
1693
1694pub(super) fn equal_rows_arr(
1695 indices_left: &UInt64Array,
1696 indices_right: &UInt32Array,
1697 left_arrays: &[ArrayRef],
1698 right_arrays: &[ArrayRef],
1699 null_equality: NullEquality,
1700) -> Result<(UInt64Array, UInt32Array)> {
1701 let mut iter = left_arrays.iter().zip(right_arrays.iter());
1702
1703 let Some((first_left, first_right)) = iter.next() else {
1704 return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
1705 };
1706
1707 let arr_left = take(first_left.as_ref(), indices_left, None)?;
1708 let arr_right = take(first_right.as_ref(), indices_right, None)?;
1709
1710 let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;
1711
1712 equal = iter
1716 .map(|(left, right)| {
1717 let arr_left = take(left.as_ref(), indices_left, None)?;
1718 let arr_right = take(right.as_ref(), indices_right, None)?;
1719 eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
1720 })
1721 .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
1722
1723 let filter_builder = FilterBuilder::new(&equal).optimize().build();
1724
1725 let left_filtered = filter_builder.filter(indices_left)?;
1726 let right_filtered = filter_builder.filter(indices_right)?;
1727
1728 Ok((
1729 downcast_array(left_filtered.as_ref()),
1730 downcast_array(right_filtered.as_ref()),
1731 ))
1732}
1733
1734fn eq_dyn_null(
1736 left: &dyn Array,
1737 right: &dyn Array,
1738 null_equality: NullEquality,
1739) -> Result<BooleanArray, ArrowError> {
1740 if left.data_type().is_nested() {
1744 let op = match null_equality {
1745 NullEquality::NullEqualsNothing => Operator::Eq,
1746 NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
1747 };
1748 return Ok(compare_op_for_nested(op, &left, &right)?);
1749 }
1750 match null_equality {
1751 NullEquality::NullEqualsNothing => eq(&left, &right),
1752 NullEquality::NullEqualsNull => not_distinct(&left, &right),
1753 }
1754}
1755
1756#[cfg(test)]
1757mod tests {
1758 use std::collections::HashMap;
1759 use std::pin::Pin;
1760
1761 use super::*;
1762
1763 use arrow::array::Int32Array;
1764 use arrow::datatypes::{DataType, Fields};
1765 use arrow::error::{ArrowError, Result as ArrowResult};
1766 use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
1767 use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
1768 use datafusion_physical_expr::PhysicalSortExpr;
1769
1770 use rstest::rstest;
1771
1772 fn check(
1773 left: &[Column],
1774 right: &[Column],
1775 on: &[(PhysicalExprRef, PhysicalExprRef)],
1776 ) -> Result<()> {
1777 let left = left
1778 .iter()
1779 .map(|x| x.to_owned())
1780 .collect::<HashSet<Column>>();
1781 let right = right
1782 .iter()
1783 .map(|x| x.to_owned())
1784 .collect::<HashSet<Column>>();
1785 check_join_set_is_valid(&left, &right, on)
1786 }
1787
1788 #[test]
1789 fn check_valid() -> Result<()> {
1790 let left = vec![Column::new("a", 0), Column::new("b1", 1)];
1791 let right = vec![Column::new("a", 0), Column::new("b2", 1)];
1792 let on = &[(
1793 Arc::new(Column::new("a", 0)) as _,
1794 Arc::new(Column::new("a", 0)) as _,
1795 )];
1796
1797 check(&left, &right, on)?;
1798 Ok(())
1799 }
1800
1801 #[test]
1802 fn check_not_in_right() {
1803 let left = vec![Column::new("a", 0), Column::new("b", 1)];
1804 let right = vec![Column::new("b", 0)];
1805 let on = &[(
1806 Arc::new(Column::new("a", 0)) as _,
1807 Arc::new(Column::new("a", 0)) as _,
1808 )];
1809
1810 assert!(check(&left, &right, on).is_err());
1811 }
1812
1813 #[tokio::test]
1814 async fn check_error_nesting() {
1815 let once_fut = OnceFut::<()>::new(async {
1816 arrow_err!(ArrowError::CsvError("some error".to_string()))
1817 });
1818
1819 struct TestFut(OnceFut<()>);
1820 impl Future for TestFut {
1821 type Output = ArrowResult<()>;
1822
1823 fn poll(
1824 mut self: Pin<&mut Self>,
1825 cx: &mut Context<'_>,
1826 ) -> Poll<Self::Output> {
1827 match ready!(self.0.get(cx)) {
1828 Ok(()) => Poll::Ready(Ok(())),
1829 Err(e) => Poll::Ready(Err(e.into())),
1830 }
1831 }
1832 }
1833
1834 let res = TestFut(once_fut).await;
1835 let arrow_err_from_fut = res.expect_err("once_fut always return error");
1836
1837 let wrapped_err = DataFusionError::from(arrow_err_from_fut);
1838 let root_err = wrapped_err.find_root();
1839
1840 let _expected =
1841 arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned()));
1842
1843 assert!(matches!(root_err, _expected))
1844 }
1845
1846 #[test]
1847 fn check_not_in_left() {
1848 let left = vec![Column::new("b", 0)];
1849 let right = vec![Column::new("a", 0)];
1850 let on = &[(
1851 Arc::new(Column::new("a", 0)) as _,
1852 Arc::new(Column::new("a", 0)) as _,
1853 )];
1854
1855 assert!(check(&left, &right, on).is_err());
1856 }
1857
1858 #[test]
1859 fn check_collision() {
1860 let left = vec![Column::new("a", 0), Column::new("c", 1)];
1862 let right = vec![Column::new("a", 0), Column::new("b", 1)];
1863 let on = &[(
1864 Arc::new(Column::new("a", 0)) as _,
1865 Arc::new(Column::new("b", 1)) as _,
1866 )];
1867
1868 assert!(check(&left, &right, on).is_ok());
1869 }
1870
1871 #[test]
1872 fn check_in_right() {
1873 let left = vec![Column::new("a", 0), Column::new("c", 1)];
1874 let right = vec![Column::new("b", 0)];
1875 let on = &[(
1876 Arc::new(Column::new("a", 0)) as _,
1877 Arc::new(Column::new("b", 0)) as _,
1878 )];
1879
1880 assert!(check(&left, &right, on).is_ok());
1881 }
1882
1883 #[test]
1884 fn test_join_schema() -> Result<()> {
1885 let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1886 let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1887 let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
1888 let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
1889
1890 let cases = vec![
1891 (&a, &b, JoinType::Inner, &a, &b),
1892 (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
1893 (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
1894 (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
1895 (&a, &b, JoinType::Left, &a, &b_nulls),
1897 (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
1898 (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
1899 (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
1900 (&a, &b, JoinType::Right, &a_nulls, &b),
1902 (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
1903 (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
1904 (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
1905 (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
1907 (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
1908 (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
1909 (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
1910 ];
1911
1912 for (left_in, right_in, join_type, left_out, right_out) in cases {
1913 let (schema, _) = build_join_schema(left_in, right_in, &join_type);
1914
1915 let expected_fields = left_out
1916 .fields()
1917 .iter()
1918 .cloned()
1919 .chain(right_out.fields().iter().cloned())
1920 .collect::<Fields>();
1921
1922 let expected_schema = Schema::new(expected_fields);
1923 assert_eq!(
1924 schema,
1925 expected_schema,
1926 "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
1927 left_in.fields()[0].name(),
1928 left_in.fields()[0].is_nullable(),
1929 right_in.fields()[0].name(),
1930 right_in.fields()[0].is_nullable(),
1931 join_type
1932 );
1933 }
1934
1935 Ok(())
1936 }
1937
1938 fn create_stats(
1939 num_rows: Option<usize>,
1940 column_stats: Vec<ColumnStatistics>,
1941 is_exact: bool,
1942 ) -> Statistics {
1943 Statistics {
1944 num_rows: if is_exact {
1945 num_rows.map(Exact)
1946 } else {
1947 num_rows.map(Inexact)
1948 }
1949 .unwrap_or(Absent),
1950 column_statistics: column_stats,
1951 total_byte_size: Absent,
1952 }
1953 }
1954
1955 fn create_column_stats(
1956 min: Precision<i64>,
1957 max: Precision<i64>,
1958 distinct_count: Precision<usize>,
1959 null_count: Precision<usize>,
1960 ) -> ColumnStatistics {
1961 ColumnStatistics {
1962 distinct_count,
1963 min_value: min.map(ScalarValue::from),
1964 max_value: max.map(ScalarValue::from),
1965 sum_value: Absent,
1966 null_count,
1967 }
1968 }
1969
1970 type PartialStats = (
1971 usize,
1972 Precision<i64>,
1973 Precision<i64>,
1974 Precision<usize>,
1975 Precision<usize>,
1976 );
1977
1978 #[test]
1982 fn test_inner_join_cardinality_single_column() -> Result<()> {
1983 let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
1984 (
1995 (10, Inexact(1), Inexact(10), Absent, Absent),
1996 (10, Inexact(1), Inexact(10), Absent, Absent),
1997 Some(Inexact(10)),
1998 ),
1999 (
2001 (10, Inexact(6), Inexact(10), Absent, Absent),
2002 (10, Inexact(8), Inexact(10), Absent, Absent),
2003 Some(Inexact(20)),
2004 ),
2005 (
2007 (10, Inexact(8), Inexact(10), Absent, Absent),
2008 (10, Inexact(6), Inexact(10), Absent, Absent),
2009 Some(Inexact(20)),
2010 ),
2011 (
2013 (10, Inexact(1), Inexact(15), Absent, Absent),
2014 (20, Inexact(1), Inexact(40), Absent, Absent),
2015 Some(Inexact(10)),
2016 ),
2017 (
2019 (10, Inexact(1), Inexact(10), Inexact(10), Absent),
2020 (10, Inexact(1), Inexact(10), Inexact(10), Absent),
2021 Some(Inexact(10)),
2022 ),
2023 (
2025 (10, Inexact(1), Inexact(10), Inexact(5), Absent),
2026 (10, Inexact(1), Inexact(10), Inexact(2), Absent),
2027 Some(Inexact(20)),
2028 ),
2029 (
2031 (10, Inexact(1), Inexact(10), Inexact(2), Absent),
2032 (10, Inexact(1), Inexact(10), Inexact(5), Absent),
2033 Some(Inexact(20)),
2034 ),
2035 (
2037 (10, Inexact(-5), Inexact(5), Absent, Absent),
2038 (10, Inexact(1), Inexact(5), Absent, Absent),
2039 Some(Inexact(10)),
2040 ),
2041 (
2043 (10, Inexact(-25), Inexact(-20), Absent, Absent),
2044 (10, Inexact(-25), Inexact(-15), Absent, Absent),
2045 Some(Inexact(10)),
2046 ),
2047 (
2052 (10, Inexact(-10), Inexact(0), Absent, Absent),
2053 (10, Inexact(0), Inexact(10), Inexact(5), Absent),
2054 Some(Inexact(10)),
2055 ),
2056 (
2058 (10, Inexact(1), Inexact(1), Absent, Absent),
2059 (10, Inexact(1), Inexact(1), Absent, Absent),
2060 Some(Inexact(100)),
2061 ),
2062 (
2068 (10, Absent, Absent, Absent, Absent),
2069 (10, Absent, Absent, Absent, Absent),
2070 None,
2071 ),
2072 (
2074 (10, Absent, Absent, Inexact(3), Absent),
2075 (10, Absent, Absent, Inexact(3), Absent),
2076 None,
2077 ),
2078 (
2079 (10, Inexact(2), Absent, Inexact(3), Absent),
2080 (10, Absent, Inexact(5), Inexact(3), Absent),
2081 None,
2082 ),
2083 (
2084 (10, Absent, Inexact(3), Inexact(3), Absent),
2085 (10, Inexact(1), Absent, Inexact(3), Absent),
2086 None,
2087 ),
2088 (
2089 (10, Absent, Inexact(3), Absent, Absent),
2090 (10, Inexact(1), Absent, Absent, Absent),
2091 None,
2092 ),
2093 (
2095 (10, Absent, Inexact(4), Absent, Absent),
2096 (10, Inexact(5), Absent, Absent, Absent),
2097 Some(Inexact(0)),
2098 ),
2099 (
2100 (10, Inexact(0), Inexact(10), Absent, Absent),
2101 (10, Inexact(11), Inexact(20), Absent, Absent),
2102 Some(Inexact(0)),
2103 ),
2104 (
2105 (10, Inexact(11), Inexact(20), Absent, Absent),
2106 (10, Inexact(0), Inexact(10), Absent, Absent),
2107 Some(Inexact(0)),
2108 ),
2109 (
2111 (10, Inexact(1), Inexact(10), Inexact(0), Absent),
2112 (10, Inexact(1), Inexact(10), Inexact(0), Absent),
2113 None,
2114 ),
2115 (
2117 (0, Inexact(1), Inexact(10), Absent, Exact(5)),
2118 (10, Inexact(1), Inexact(10), Absent, Absent),
2119 Some(Inexact(0)),
2120 ),
2121 ];
2122
2123 for (left_info, right_info, expected_cardinality) in cases {
2124 let left_num_rows = left_info.0;
2125 let left_col_stats = vec![create_column_stats(
2126 left_info.1,
2127 left_info.2,
2128 left_info.3,
2129 left_info.4,
2130 )];
2131
2132 let right_num_rows = right_info.0;
2133 let right_col_stats = vec![create_column_stats(
2134 right_info.1,
2135 right_info.2,
2136 right_info.3,
2137 right_info.4,
2138 )];
2139
2140 assert_eq!(
2141 estimate_inner_join_cardinality(
2142 Statistics {
2143 num_rows: Inexact(left_num_rows),
2144 total_byte_size: Absent,
2145 column_statistics: left_col_stats.clone(),
2146 },
2147 Statistics {
2148 num_rows: Inexact(right_num_rows),
2149 total_byte_size: Absent,
2150 column_statistics: right_col_stats.clone(),
2151 },
2152 ),
2153 expected_cardinality.clone()
2154 );
2155
2156 let join_type = JoinType::Inner;
2158 let join_on = vec![(
2159 Arc::new(Column::new("a", 0)) as _,
2160 Arc::new(Column::new("b", 0)) as _,
2161 )];
2162 let partial_join_stats = estimate_join_cardinality(
2163 &join_type,
2164 create_stats(Some(left_num_rows), left_col_stats.clone(), false),
2165 create_stats(Some(right_num_rows), right_col_stats.clone(), false),
2166 &join_on,
2167 );
2168
2169 assert_eq!(
2170 partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
2171 expected_cardinality.clone()
2172 );
2173 assert_eq!(
2174 partial_join_stats.map(|s| s.column_statistics),
2175 expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
2176 );
2177 }
2178 Ok(())
2179 }
2180
2181 #[test]
2182 fn test_inner_join_cardinality_multiple_column() -> Result<()> {
2183 let left_col_stats = vec![
2184 create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2185 create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
2186 ];
2187
2188 let right_col_stats = vec![
2189 create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2190 create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
2191 ];
2192
2193 assert_eq!(
2196 estimate_inner_join_cardinality(
2197 Statistics {
2198 num_rows: Inexact(400),
2199 total_byte_size: Absent,
2200 column_statistics: left_col_stats,
2201 },
2202 Statistics {
2203 num_rows: Inexact(400),
2204 total_byte_size: Absent,
2205 column_statistics: right_col_stats,
2206 },
2207 ),
2208 Some(Inexact((400 * 400) / 200))
2209 );
2210 Ok(())
2211 }
2212
2213 #[test]
2214 fn test_inner_join_cardinality_decimal_range() -> Result<()> {
2215 let left_col_stats = vec![ColumnStatistics {
2216 distinct_count: Absent,
2217 min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)),
2218 max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)),
2219 ..Default::default()
2220 }];
2221
2222 let right_col_stats = vec![ColumnStatistics {
2223 distinct_count: Absent,
2224 min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)),
2225 max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)),
2226 ..Default::default()
2227 }];
2228
2229 assert_eq!(
2230 estimate_inner_join_cardinality(
2231 Statistics {
2232 num_rows: Inexact(100),
2233 total_byte_size: Absent,
2234 column_statistics: left_col_stats,
2235 },
2236 Statistics {
2237 num_rows: Inexact(100),
2238 total_byte_size: Absent,
2239 column_statistics: right_col_stats,
2240 },
2241 ),
2242 Some(Inexact(100))
2243 );
2244 Ok(())
2245 }
2246
2247 #[test]
2248 fn test_join_cardinality() -> Result<()> {
2249 let cases = vec![
2261 (JoinType::Inner, 800),
2262 (JoinType::Left, 1000),
2263 (JoinType::Right, 2000),
2264 (JoinType::Full, 2200),
2265 ];
2266
2267 let left_col_stats = vec![
2268 create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2269 create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2270 create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2271 ];
2272
2273 let right_col_stats = vec![
2274 create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2275 create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2276 create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2277 ];
2278
2279 for (join_type, expected_num_rows) in cases {
2280 let join_on = vec![
2281 (
2282 Arc::new(Column::new("a", 0)) as _,
2283 Arc::new(Column::new("c", 0)) as _,
2284 ),
2285 (
2286 Arc::new(Column::new("b", 1)) as _,
2287 Arc::new(Column::new("d", 1)) as _,
2288 ),
2289 ];
2290
2291 let partial_join_stats = estimate_join_cardinality(
2292 &join_type,
2293 create_stats(Some(1000), left_col_stats.clone(), false),
2294 create_stats(Some(2000), right_col_stats.clone(), false),
2295 &join_on,
2296 )
2297 .unwrap();
2298 assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2299 assert_eq!(
2300 partial_join_stats.column_statistics,
2301 [left_col_stats.clone(), right_col_stats.clone()].concat()
2302 );
2303 }
2304
2305 Ok(())
2306 }
2307
2308 #[test]
2309 fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
2310 let left_col_stats = vec![
2323 create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2324 create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2325 create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2326 ];
2327
2328 let right_col_stats = vec![
2329 create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2330 create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2331 create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2332 ];
2333
2334 let join_on = vec![
2335 (
2336 Arc::new(Column::new("a", 0)) as _,
2337 Arc::new(Column::new("c", 0)) as _,
2338 ),
2339 (
2340 Arc::new(Column::new("x", 2)) as _,
2341 Arc::new(Column::new("y", 2)) as _,
2342 ),
2343 ];
2344
2345 let cases = vec![
2346 (JoinType::Inner, 0),
2351 (JoinType::Left, 1000),
2354 (JoinType::Right, 2000),
2355 (JoinType::Full, 3000),
2359 ];
2360
2361 for (join_type, expected_num_rows) in cases {
2362 let partial_join_stats = estimate_join_cardinality(
2363 &join_type,
2364 create_stats(Some(1000), left_col_stats.clone(), true),
2365 create_stats(Some(2000), right_col_stats.clone(), true),
2366 &join_on,
2367 )
2368 .unwrap();
2369 assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2370 assert_eq!(
2371 partial_join_stats.column_statistics,
2372 [left_col_stats.clone(), right_col_stats.clone()].concat()
2373 );
2374 }
2375
2376 Ok(())
2377 }
2378
2379 #[test]
2380 fn test_anti_semi_join_cardinality() -> Result<()> {
2381 let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
2382 (
2392 JoinType::LeftSemi,
2393 (50, Inexact(10), Inexact(20), Absent, Absent),
2394 (10, Inexact(15), Inexact(25), Absent, Absent),
2395 Some(50),
2396 ),
2397 (
2398 JoinType::RightSemi,
2399 (50, Inexact(10), Inexact(20), Absent, Absent),
2400 (10, Inexact(15), Inexact(25), Absent, Absent),
2401 Some(10),
2402 ),
2403 (
2404 JoinType::LeftSemi,
2405 (10, Absent, Absent, Absent, Absent),
2406 (50, Absent, Absent, Absent, Absent),
2407 Some(10),
2408 ),
2409 (
2410 JoinType::LeftSemi,
2411 (50, Inexact(10), Inexact(20), Absent, Absent),
2412 (10, Inexact(30), Inexact(40), Absent, Absent),
2413 Some(0),
2414 ),
2415 (
2416 JoinType::LeftSemi,
2417 (50, Inexact(10), Absent, Absent, Absent),
2418 (10, Absent, Inexact(5), Absent, Absent),
2419 Some(0),
2420 ),
2421 (
2422 JoinType::LeftSemi,
2423 (50, Absent, Inexact(20), Absent, Absent),
2424 (10, Inexact(30), Absent, Absent, Absent),
2425 Some(0),
2426 ),
2427 (
2428 JoinType::LeftAnti,
2429 (50, Inexact(10), Inexact(20), Absent, Absent),
2430 (10, Inexact(15), Inexact(25), Absent, Absent),
2431 Some(50),
2432 ),
2433 (
2434 JoinType::RightAnti,
2435 (50, Inexact(10), Inexact(20), Absent, Absent),
2436 (10, Inexact(15), Inexact(25), Absent, Absent),
2437 Some(10),
2438 ),
2439 (
2440 JoinType::LeftAnti,
2441 (10, Absent, Absent, Absent, Absent),
2442 (50, Absent, Absent, Absent, Absent),
2443 Some(10),
2444 ),
2445 (
2446 JoinType::LeftAnti,
2447 (50, Inexact(10), Inexact(20), Absent, Absent),
2448 (10, Inexact(30), Inexact(40), Absent, Absent),
2449 Some(50),
2450 ),
2451 (
2452 JoinType::LeftAnti,
2453 (50, Inexact(10), Absent, Absent, Absent),
2454 (10, Absent, Inexact(5), Absent, Absent),
2455 Some(50),
2456 ),
2457 (
2458 JoinType::LeftAnti,
2459 (50, Absent, Inexact(20), Absent, Absent),
2460 (10, Inexact(30), Absent, Absent, Absent),
2461 Some(50),
2462 ),
2463 ];
2464
2465 let join_on = vec![(
2466 Arc::new(Column::new("l_col", 0)) as _,
2467 Arc::new(Column::new("r_col", 0)) as _,
2468 )];
2469
2470 for (join_type, outer_info, inner_info, expected) in cases {
2471 let outer_num_rows = outer_info.0;
2472 let outer_col_stats = vec![create_column_stats(
2473 outer_info.1,
2474 outer_info.2,
2475 outer_info.3,
2476 outer_info.4,
2477 )];
2478
2479 let inner_num_rows = inner_info.0;
2480 let inner_col_stats = vec![create_column_stats(
2481 inner_info.1,
2482 inner_info.2,
2483 inner_info.3,
2484 inner_info.4,
2485 )];
2486
2487 let output_cardinality = estimate_join_cardinality(
2488 &join_type,
2489 Statistics {
2490 num_rows: Inexact(outer_num_rows),
2491 total_byte_size: Absent,
2492 column_statistics: outer_col_stats,
2493 },
2494 Statistics {
2495 num_rows: Inexact(inner_num_rows),
2496 total_byte_size: Absent,
2497 column_statistics: inner_col_stats,
2498 },
2499 &join_on,
2500 )
2501 .map(|cardinality| cardinality.num_rows);
2502
2503 assert_eq!(
2504 output_cardinality, expected,
2505 "failure for join_type: {join_type}"
2506 );
2507 }
2508
2509 Ok(())
2510 }
2511
2512 #[test]
2513 fn test_semi_join_cardinality_absent_rows() -> Result<()> {
2514 let dummy_column_stats =
2515 vec![create_column_stats(Absent, Absent, Absent, Absent)];
2516 let join_on = vec![(
2517 Arc::new(Column::new("l_col", 0)) as _,
2518 Arc::new(Column::new("r_col", 0)) as _,
2519 )];
2520
2521 let absent_outer_estimation = estimate_join_cardinality(
2522 &JoinType::LeftSemi,
2523 Statistics {
2524 num_rows: Absent,
2525 total_byte_size: Absent,
2526 column_statistics: dummy_column_stats.clone(),
2527 },
2528 Statistics {
2529 num_rows: Exact(10),
2530 total_byte_size: Absent,
2531 column_statistics: dummy_column_stats.clone(),
2532 },
2533 &join_on,
2534 );
2535 assert!(
2536 absent_outer_estimation.is_none(),
2537 "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
2538 );
2539
2540 let absent_inner_estimation = estimate_join_cardinality(
2541 &JoinType::LeftSemi,
2542 Statistics {
2543 num_rows: Inexact(500),
2544 total_byte_size: Absent,
2545 column_statistics: dummy_column_stats.clone(),
2546 },
2547 Statistics {
2548 num_rows: Absent,
2549 total_byte_size: Absent,
2550 column_statistics: dummy_column_stats.clone(),
2551 },
2552 &join_on,
2553 ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows");
2554
2555 assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows");
2556
2557 let absent_inner_estimation = estimate_join_cardinality(
2558 &JoinType::LeftSemi,
2559 Statistics {
2560 num_rows: Absent,
2561 total_byte_size: Absent,
2562 column_statistics: dummy_column_stats.clone(),
2563 },
2564 Statistics {
2565 num_rows: Absent,
2566 total_byte_size: Absent,
2567 column_statistics: dummy_column_stats,
2568 },
2569 &join_on,
2570 );
2571 assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows");
2572
2573 Ok(())
2574 }
2575
/// Checks the output ordering computed by `calculate_join_output_ordering`
/// for an inner join with five left columns.
///
/// The expectations encode that the ordering of the side whose input order
/// is maintained comes first, and that right-side column indices are shifted
/// by the number of left columns (`z`: 2 -> 7, `y`: 1 -> 6).
#[test]
fn test_calculate_join_output_ordering() -> Result<()> {
    // Shorthand for a default sort expression on column `name` at `index`.
    let sort_on = |name: &str, index: usize| {
        PhysicalSortExpr::new_default(Arc::new(Column::new(name, index)))
    };

    let left_ordering =
        LexOrdering::new(vec![sort_on("a", 0), sort_on("c", 2), sort_on("d", 3)]);
    let right_ordering = LexOrdering::new(vec![sort_on("z", 2), sort_on("y", 1)]);
    let left_columns_len = 5;

    // Each case: (maintains_input_order, probe_side, expected output ordering).
    let cases = [
        (
            [true, false],
            Some(JoinSide::Left),
            LexOrdering::new(vec![
                sort_on("a", 0),
                sort_on("c", 2),
                sort_on("d", 3),
                sort_on("z", 7),
                sort_on("y", 6),
            ]),
        ),
        (
            [false, true],
            Some(JoinSide::Right),
            LexOrdering::new(vec![
                sort_on("z", 7),
                sort_on("y", 6),
                sort_on("a", 0),
                sort_on("c", 2),
                sort_on("d", 3),
            ]),
        ),
    ];

    for (maintains_input_order, probe_side, expected) in cases {
        let actual = calculate_join_output_ordering(
            left_ordering.as_ref(),
            right_ordering.as_ref(),
            JoinType::Inner,
            left_columns_len,
            &maintains_input_order,
            probe_side,
        )?;
        assert_eq!(actual, expected);
    }

    Ok(())
}
2627
2628 fn create_test_batch(num_rows: usize) -> RecordBatch {
2629 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2630 let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
2631 RecordBatch::try_new(schema, vec![data]).unwrap()
2632 }
2633
2634 fn assert_split_batches(
2635 batches: Vec<(RecordBatch, bool)>,
2636 batch_size: usize,
2637 num_rows: usize,
2638 ) {
2639 let mut row_count = 0;
2640 for (batch, last) in batches.into_iter() {
2641 assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size));
2642 let column = batch
2643 .column(0)
2644 .as_any()
2645 .downcast_ref::<Int32Array>()
2646 .unwrap();
2647 for i in 0..batch.num_rows() {
2648 assert_eq!(column.value(i), i as i32 + row_count as i32);
2649 }
2650 row_count += batch.num_rows();
2651 assert_eq!(last, row_count == num_rows);
2652 }
2653 }
2654
2655 #[rstest]
2656 #[test]
2657 fn test_batch_splitter(
2658 #[values(1, 3, 11)] batch_size: usize,
2659 #[values(1, 6, 50)] num_rows: usize,
2660 ) {
2661 let mut splitter = BatchSplitter::new(batch_size);
2662 splitter.set_batch(create_test_batch(num_rows));
2663
2664 let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size));
2665 while let Some(batch) = splitter.next() {
2666 batches.push(batch);
2667 }
2668
2669 assert!(splitter.next().is_none());
2670 assert_split_batches(batches, batch_size, num_rows);
2671 }
2672
/// Checks the projection produced by `swap_reverting_projection` for a
/// two-column left schema (`a`, `b`) and a one-column right schema (`c`).
///
/// The expected projection keeps the aliases in `a`, `b`, `c` order while
/// reading them from column indices 1, 2 and 0 respectively.
///
/// This is a plain synchronous test: nothing here is awaited, so no async
/// runtime is needed.
#[test]
fn test_swap_reverting_projection() {
    let left_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);

    let proj = swap_reverting_projection(&left_schema, &right_schema);

    // (alias / column name, expected column index), in output order.
    let expected = [("a", 1), ("b", 2), ("c", 0)];
    assert_eq!(proj.len(), expected.len());

    for (proj_expr, (name, index)) in proj.iter().zip(expected) {
        assert_eq!(proj_expr.alias, name);
        assert_col_expr(&proj_expr.expr, name, index);
    }
}
2698
2699 fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
2700 let col = expr
2701 .as_any()
2702 .downcast_ref::<Column>()
2703 .expect("Projection items should be Column expression");
2704 assert_eq!(col.name(), name);
2705 assert_eq!(col.index(), index);
2706 }
2707
/// Checks which input's schema metadata `build_join_schema` keeps: the left
/// schema's for `JoinType::Left` and the right schema's for `JoinType::Right`.
#[test]
fn test_join_metadata() -> Result<()> {
    // Metadata map tagging a schema with the side it came from.
    let metadata_for =
        |side: &str| HashMap::from([("key".to_string(), side.to_string())]);

    let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)])
        .with_metadata(metadata_for("left"));
    let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)])
        .with_metadata(metadata_for("right"));

    for (join_type, side) in [(JoinType::Left, "left"), (JoinType::Right, "right")] {
        let (join_schema, _) =
            build_join_schema(&left_schema, &right_schema, &join_type);
        assert_eq!(join_schema.metadata(), &metadata_for(side));
    }

    Ok(())
}
2731}