datafusion_physical_plan/joins/utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Join related functionality used both on logical and physical plans
19
20use std::cmp::{min, Ordering};
21use std::collections::HashSet;
22use std::fmt::{self, Debug};
23use std::future::Future;
24use std::iter::once;
25use std::ops::Range;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29use crate::joins::SharedBitmapBuilder;
30use crate::metrics::{self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
31use crate::projection::{ProjectionExec, ProjectionExpr};
32use crate::{
33    ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
34};
35// compatibility
36pub use super::join_filter::JoinFilter;
37pub use super::join_hash_map::JoinHashMapType;
38pub use crate::joins::{JoinOn, JoinOnRef};
39
40use ahash::RandomState;
41use arrow::array::{
42    builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType,
43    BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions,
44    UInt32Array, UInt32Builder, UInt64Array,
45};
46use arrow::array::{
47    ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
48    Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array,
49    Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray,
50    StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray,
51    TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array,
52};
53use arrow::buffer::{BooleanBuffer, NullBuffer};
54use arrow::compute::kernels::cmp::eq;
55use arrow::compute::{self, and, take, FilterBuilder};
56use arrow::datatypes::{
57    ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
58};
59use arrow_ord::cmp::not_distinct;
60use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit};
61use datafusion_common::cast::as_boolean_array;
62use datafusion_common::hash_utils::create_hashes;
63use datafusion_common::stats::Precision;
64use datafusion_common::{
65    not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result,
66    SharedResult,
67};
68use datafusion_expr::interval_arithmetic::Interval;
69use datafusion_expr::Operator;
70use datafusion_physical_expr::expressions::Column;
71use datafusion_physical_expr::utils::collect_columns;
72use datafusion_physical_expr::{
73    add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr,
74    PhysicalExprRef,
75};
76
77use datafusion_physical_expr_common::datum::compare_op_for_nested;
78use futures::future::{BoxFuture, Shared};
79use futures::{ready, FutureExt};
80use parking_lot::Mutex;
81
/// Checks whether the schemas `left` and `right` and the `on` columns represent a valid join.
/// They are valid whenever every column referenced in `on` exists in the corresponding schema.
84pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
85    let left: HashSet<Column> = left
86        .fields()
87        .iter()
88        .enumerate()
89        .map(|(idx, f)| Column::new(f.name(), idx))
90        .collect();
91    let right: HashSet<Column> = right
92        .fields()
93        .iter()
94        .enumerate()
95        .map(|(idx, f)| Column::new(f.name(), idx))
96        .collect();
97
98    check_join_set_is_valid(&left, &right, on)
99}
100
/// Checks whether the column sets `left`, `right` and the `on` expressions compose a valid join.
/// They are valid whenever the columns referenced by `on` are a subset of `left` and `right`, respectively.
103fn check_join_set_is_valid(
104    left: &HashSet<Column>,
105    right: &HashSet<Column>,
106    on: &[(PhysicalExprRef, PhysicalExprRef)],
107) -> Result<()> {
108    let on_left = &on
109        .iter()
110        .flat_map(|on| collect_columns(&on.0))
111        .collect::<HashSet<_>>();
112    let left_missing = on_left.difference(left).collect::<HashSet<_>>();
113
114    let on_right = &on
115        .iter()
116        .flat_map(|on| collect_columns(&on.1))
117        .collect::<HashSet<_>>();
118    let right_missing = on_right.difference(right).collect::<HashSet<_>>();
119
    if !left_missing.is_empty() || !right_missing.is_empty() {
121        return plan_err!(
122            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}"
123        );
124    };
125
126    Ok(())
127}
128
/// Adjusts the right input's output partitioning to the new column indices by
/// offsetting its expressions by the number of left columns.
130pub fn adjust_right_output_partitioning(
131    right_partitioning: &Partitioning,
132    left_columns_len: usize,
133) -> Result<Partitioning> {
134    let result = match right_partitioning {
135        Partitioning::Hash(exprs, size) => {
136            let new_exprs = exprs
137                .iter()
138                .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _))
139                .collect::<Result<_>>()?;
140            Partitioning::Hash(new_exprs, *size)
141        }
142        result => result.clone(),
143    };
144    Ok(result)
145}
146
147/// Calculate the output ordering of a given join operation.
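///
/// For example (illustrative): when only the left side's order is maintained
/// (`maintains_input_order = [true, false]`) and the join is an `Inner` join that
/// probes the left side, the output ordering is the left ordering followed by the
/// right ordering, with the right sort expressions' column indices offset by
/// `left_columns_len`.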
148pub fn calculate_join_output_ordering(
149    left_ordering: Option<&LexOrdering>,
150    right_ordering: Option<&LexOrdering>,
151    join_type: JoinType,
152    left_columns_len: usize,
153    maintains_input_order: &[bool],
154    probe_side: Option<JoinSide>,
155) -> Result<Option<LexOrdering>> {
156    match maintains_input_order {
157        [true, false] => {
158            // Special case, we can prefix ordering of right side with the ordering of left side.
159            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) {
160                if let Some(right_ordering) = right_ordering.cloned() {
161                    let right_offset = add_offset_to_physical_sort_exprs(
162                        right_ordering,
163                        left_columns_len as _,
164                    )?;
165                    return if let Some(left_ordering) = left_ordering {
166                        let mut result = left_ordering.clone();
167                        result.extend(right_offset);
168                        Ok(Some(result))
169                    } else {
170                        Ok(LexOrdering::new(right_offset))
171                    };
172                }
173            }
174            Ok(left_ordering.cloned())
175        }
176        [false, true] => {
177            // Special case, we can prefix ordering of left side with the ordering of right side.
178            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) {
179                return if let Some(right_ordering) = right_ordering.cloned() {
180                    let mut right_offset = add_offset_to_physical_sort_exprs(
181                        right_ordering,
182                        left_columns_len as _,
183                    )?;
184                    if let Some(left_ordering) = left_ordering {
185                        right_offset.extend(left_ordering.clone());
186                    }
187                    Ok(LexOrdering::new(right_offset))
188                } else {
189                    Ok(left_ordering.cloned())
190                };
191            }
192            let Some(right_ordering) = right_ordering else {
193                return Ok(None);
194            };
195            match join_type {
196                JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
197                    add_offset_to_physical_sort_exprs(
198                        right_ordering.clone(),
199                        left_columns_len as _,
200                    )
201                    .map(LexOrdering::new)
202                }
203                _ => Ok(Some(right_ordering.clone())),
204            }
205        }
206        // Doesn't maintain ordering, output ordering is None.
207        [false, false] => Ok(None),
208        [true, true] => unreachable!("Cannot maintain ordering of both sides"),
209        _ => unreachable!("Join operators can not have more than two children"),
210    }
211}
212
213/// Information about the index and placement (left or right) of the columns
214#[derive(Debug, Clone, PartialEq)]
215pub struct ColumnIndex {
216    /// Index of the column
217    pub index: usize,
218    /// Whether the column is at the left or right side
219    pub side: JoinSide,
220}
221
/// Returns the output field given the input field. Outer joins may
/// insert nulls even if the input field was not nullable.
224fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
225    let force_nullable = match join_type {
226        JoinType::Inner => false,
227        JoinType::Left => !is_left, // right input is padded with nulls
228        JoinType::Right => is_left, // left input is padded with nulls
229        JoinType::Full => true,     // both inputs can be padded with nulls
230        JoinType::LeftSemi => false, // doesn't introduce nulls
231        JoinType::RightSemi => false, // doesn't introduce nulls
        JoinType::LeftAnti => false, // doesn't introduce nulls: only unmatched left rows are emitted
        JoinType::RightAnti => false, // doesn't introduce nulls: only unmatched right rows are emitted
234        JoinType::LeftMark => false,
235        JoinType::RightMark => false,
236    };
237
238    if force_nullable {
239        old_field.clone().with_nullable(true)
240    } else {
241        old_field.clone()
242    }
243}
244
/// Creates the output schema for a join operation.
/// For join types that output both sides, the fields from the left side come first.
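///
/// For example (illustrative): an `Inner` join of a left schema `(a, b)` with a
/// right schema `(c)` yields the output schema `(a, b, c)`, with column indices
/// `[(0, Left), (1, Left), (0, Right)]`.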
247pub fn build_join_schema(
248    left: &Schema,
249    right: &Schema,
250    join_type: &JoinType,
251) -> (Schema, Vec<ColumnIndex>) {
252    let left_fields = || {
253        left.fields()
254            .iter()
255            .map(|f| output_join_field(f, join_type, true))
256            .enumerate()
257            .map(|(index, f)| {
258                (
259                    f,
260                    ColumnIndex {
261                        index,
262                        side: JoinSide::Left,
263                    },
264                )
265            })
266    };
267
268    let right_fields = || {
269        right
270            .fields()
271            .iter()
272            .map(|f| output_join_field(f, join_type, false))
273            .enumerate()
274            .map(|(index, f)| {
275                (
276                    f,
277                    ColumnIndex {
278                        index,
279                        side: JoinSide::Right,
280                    },
281                )
282            })
283    };
284
285    let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
286        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
287            // left then right
288            left_fields().chain(right_fields()).unzip()
289        }
290        JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
291        JoinType::LeftMark => {
292            let right_field = once((
293                Field::new("mark", DataType::Boolean, false),
294                ColumnIndex {
295                    index: 0,
296                    side: JoinSide::None,
297                },
298            ));
299            left_fields().chain(right_field).unzip()
300        }
301        JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
302        JoinType::RightMark => {
303            let left_field = once((
304                Field::new("mark", DataType::Boolean, false),
305                ColumnIndex {
306                    index: 0,
307                    side: JoinSide::None,
308                },
309            ));
310            right_fields().chain(left_field).unzip()
311        }
312    };
313
314    let (schema1, schema2) = match join_type {
315        JoinType::Right
316        | JoinType::RightSemi
317        | JoinType::RightAnti
318        | JoinType::RightMark => (left, right),
319        _ => (right, left),
320    };
321
322    let metadata = schema1
323        .metadata()
324        .clone()
325        .into_iter()
326        .chain(schema2.metadata().clone())
327        .collect();
328
329    (fields.finish().with_metadata(metadata), column_indices)
330}
331
332/// A [`OnceAsync`] runs an `async` closure once, where multiple calls to
333/// [`OnceAsync::try_once`] return a [`OnceFut`] that resolves to the result of the
334/// same computation.
335///
/// This is useful for joins where the results of one child are needed to proceed
/// with multiple output streams.
///
340/// For example, in a hash join, one input is buffered and shared across
341/// potentially multiple output partitions. Each output partition must wait for
342/// the hash table to be built before proceeding.
343///
344/// Each output partition waits on the same `OnceAsync` before proceeding.
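///
/// A minimal usage sketch (illustrative only; `collect_build_side` and
/// `BuildSideData` are hypothetical names, not part of this module):
///
/// ```rust,ignore
/// // Shared by all output partitions of the join operator.
/// let once: OnceAsync<BuildSideData> = OnceAsync::default();
///
/// // Each partition obtains a `OnceFut` pointing at the same computation;
/// // the closure runs only for the first caller.
/// let build_fut: OnceFut<BuildSideData> =
///     once.try_once(|| Ok(async move { collect_build_side().await }))?;
/// ```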
345pub(crate) struct OnceAsync<T> {
346    fut: Mutex<Option<SharedResult<OnceFut<T>>>>,
347}
348
349impl<T> Default for OnceAsync<T> {
350    fn default() -> Self {
351        Self {
352            fut: Mutex::new(None),
353        }
354    }
355}
356
357impl<T> Debug for OnceAsync<T> {
358    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359        write!(f, "OnceAsync")
360    }
361}
362
363impl<T: 'static> OnceAsync<T> {
    /// If this is the first call to this function on this object, invokes
    /// `f` to obtain a future and returns a [`OnceFut`] referring to it. `f`
    /// may fail, in which case its error is returned.
367    ///
368    /// If this is not the first call, will return a [`OnceFut`] referring
369    /// to the same future as was returned by the first call - or the same
370    /// error if the initial call to `f` failed.
371    pub(crate) fn try_once<F, Fut>(&self, f: F) -> Result<OnceFut<T>>
372    where
373        F: FnOnce() -> Result<Fut>,
374        Fut: Future<Output = Result<T>> + Send + 'static,
375    {
376        self.fut
377            .lock()
378            .get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new))
379            .clone()
380            .map_err(DataFusionError::Shared)
381    }
382}
383
384/// The shared future type used internally within [`OnceAsync`]
385type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;
386
/// A [`OnceFut`] represents a shared asynchronous computation that is evaluated
/// once for all of its clones, with [`OnceFut::get`] providing a non-consuming interface
/// to drive the underlying [`Future`] to completion.
390pub(crate) struct OnceFut<T> {
391    state: OnceFutState<T>,
392}
393
394impl<T> Clone for OnceFut<T> {
395    fn clone(&self) -> Self {
396        Self {
397            state: self.state.clone(),
398        }
399    }
400}
401
/// Partial statistics (row count and per-column statistics) estimated
/// for the output of a join operation.
404#[derive(Clone, Debug, Default)]
405struct PartialJoinStatistics {
406    pub num_rows: usize,
407    pub column_statistics: Vec<ColumnStatistics>,
408}
409
410/// Estimate the statistics for the given join's output.
411pub(crate) fn estimate_join_statistics(
412    left_stats: Statistics,
413    right_stats: Statistics,
414    on: JoinOn,
415    join_type: &JoinType,
416    schema: &Schema,
417) -> Result<Statistics> {
418    let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on);
419    let (num_rows, column_statistics) = match join_stats {
420        Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics),
421        None => (Precision::Absent, Statistics::unknown_column(schema)),
422    };
423    Ok(Statistics {
424        num_rows,
425        total_byte_size: Precision::Absent,
426        column_statistics,
427    })
428}
429
430// Estimate the cardinality for the given join with input statistics.
431fn estimate_join_cardinality(
432    join_type: &JoinType,
433    left_stats: Statistics,
434    right_stats: Statistics,
435    on: &JoinOn,
436) -> Option<PartialJoinStatistics> {
437    let (left_col_stats, right_col_stats) = on
438        .iter()
439        .map(|(left, right)| {
440            match (
441                left.as_any().downcast_ref::<Column>(),
442                right.as_any().downcast_ref::<Column>(),
443            ) {
444                (Some(left), Some(right)) => (
445                    left_stats.column_statistics[left.index()].clone(),
446                    right_stats.column_statistics[right.index()].clone(),
447                ),
448                _ => (
449                    ColumnStatistics::new_unknown(),
450                    ColumnStatistics::new_unknown(),
451                ),
452            }
453        })
454        .unzip::<_, _, Vec<_>, Vec<_>>();
455
456    match join_type {
457        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
458            let ij_cardinality = estimate_inner_join_cardinality(
459                Statistics {
460                    num_rows: left_stats.num_rows,
461                    total_byte_size: Precision::Absent,
462                    column_statistics: left_col_stats,
463                },
464                Statistics {
465                    num_rows: right_stats.num_rows,
466                    total_byte_size: Precision::Absent,
467                    column_statistics: right_col_stats,
468                },
469            )?;
470
            // The inner join cardinality can also be used to estimate
            // the cardinality of left/right/full outer joins, as long as it
            // is greater than the minimum cardinality constraints of these
            // joins (so that we don't underestimate the cardinality).
475            let cardinality = match join_type {
476                JoinType::Inner => ij_cardinality,
477                JoinType::Left => ij_cardinality.max(&left_stats.num_rows),
478                JoinType::Right => ij_cardinality.max(&right_stats.num_rows),
479                JoinType::Full => ij_cardinality
480                    .max(&left_stats.num_rows)
481                    .add(&ij_cardinality.max(&right_stats.num_rows))
482                    .sub(&ij_cardinality),
483                _ => unreachable!(),
484            };
485
486            Some(PartialJoinStatistics {
487                num_rows: *cardinality.get_value()?,
                // We don't do anything specific here, just combine the existing
                // statistics, which might yield subpar results (the min/max values
                // remain valid bounds, though). For a better estimate, we would need
                // filter selectivity analysis first.
492                column_statistics: left_stats
493                    .column_statistics
494                    .into_iter()
495                    .chain(right_stats.column_statistics)
496                    .collect(),
497            })
498        }
499
        // For semi joins, the estimate is either zero (when the inputs are
        // non-overlapping according to statistics) or equal to the number of
        // rows in the outer input.
503        JoinType::LeftSemi | JoinType::RightSemi => {
504            let (outer_stats, inner_stats) = match join_type {
505                JoinType::LeftSemi => (left_stats, right_stats),
506                _ => (right_stats, left_stats),
507            };
508            let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
509                Some(estimation) => *estimation.get_value()?,
510                None => *outer_stats.num_rows.get_value()?,
511            };
512
513            Some(PartialJoinStatistics {
514                num_rows: cardinality,
515                column_statistics: outer_stats.column_statistics,
516            })
517        }
518
        // For anti joins, the estimate always equals the outer input's statistics,
        // since non-overlapping inputs do not affect the estimate.
521        JoinType::LeftAnti | JoinType::RightAnti => {
522            let outer_stats = match join_type {
523                JoinType::LeftAnti => left_stats,
524                _ => right_stats,
525            };
526
527            Some(PartialJoinStatistics {
528                num_rows: *outer_stats.num_rows.get_value()?,
529                column_statistics: outer_stats.column_statistics,
530            })
531        }
532
533        JoinType::LeftMark => {
534            let num_rows = *left_stats.num_rows.get_value()?;
535            let mut column_statistics = left_stats.column_statistics;
536            column_statistics.push(ColumnStatistics::new_unknown());
537            Some(PartialJoinStatistics {
538                num_rows,
539                column_statistics,
540            })
541        }
542        JoinType::RightMark => {
543            let num_rows = *right_stats.num_rows.get_value()?;
544            let mut column_statistics = right_stats.column_statistics;
545            column_statistics.push(ColumnStatistics::new_unknown());
546            Some(PartialJoinStatistics {
547                num_rows,
548                column_statistics,
549            })
550        }
551    }
552}
553
/// Estimate the inner join cardinality using the basic building blocks of
/// column-level statistics and the total row count. This is a very naive and
/// conservative implementation that quickly gives up if there are not
/// enough input statistics.
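///
/// A worked example (illustrative numbers): with 1_000 rows on the left,
/// 2_000 rows on the right, and a maximum distinct count of 100 on the join
/// key, the estimate is `(1_000 * 2_000) / 100 = 20_000` rows.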
558fn estimate_inner_join_cardinality(
559    left_stats: Statistics,
560    right_stats: Statistics,
561) -> Option<Precision<usize>> {
    // Immediately return if the inputs are considered non-overlapping
563    if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) {
564        return Some(estimation);
565    };
566
567    // The algorithm here is partly based on the non-histogram selectivity estimation
568    // from Spark's Catalyst optimizer.
569    let mut join_selectivity = Precision::Absent;
570    for (left_stat, right_stat) in left_stats
571        .column_statistics
572        .iter()
573        .zip(right_stats.column_statistics.iter())
574    {
575        let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat);
576        let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat);
577        let max_distinct = left_max_distinct.max(&right_max_distinct);
578        if max_distinct.get_value().is_some() {
579            // Seems like there are a few implementations of this algorithm that implement
580            // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
581            // further exploration.
582            join_selectivity = max_distinct;
583        }
584    }
585
586    // With the assumption that the smaller input's domain is generally represented in the bigger
587    // input's domain, we can estimate the inner join's cardinality by taking the cartesian product
588    // of the two inputs and normalizing it by the selectivity factor.
589    let left_num_rows = left_stats.num_rows.get_value()?;
590    let right_num_rows = right_stats.num_rows.get_value()?;
591    match join_selectivity {
592        Precision::Exact(value) if value > 0 => {
593            Some(Precision::Exact((left_num_rows * right_num_rows) / value))
594        }
595        Precision::Inexact(value) if value > 0 => {
596            Some(Precision::Inexact((left_num_rows * right_num_rows) / value))
597        }
598        // Since we don't have any information about the selectivity (which is derived
599        // from the number of distinct rows information) we can give up here for now.
600        // And let other passes handle this (otherwise we would need to produce an
601        // overestimation using just the cartesian product).
602        _ => None,
603    }
604}
605
/// Estimates whether the inputs are non-overlapping, using input statistics.
/// If the inputs are disjoint, returns a zero estimate; otherwise returns `None`.
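///
/// For example (illustrative): if the left join column has `min = 100` and the
/// right join column has `max = 50`, the ranges cannot overlap and the join
/// cardinality is estimated as `0` (exact if both bounds are exact statistics,
/// inexact otherwise).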
608fn estimate_disjoint_inputs(
609    left_stats: &Statistics,
610    right_stats: &Statistics,
611) -> Option<Precision<usize>> {
612    for (left_stat, right_stat) in left_stats
613        .column_statistics
614        .iter()
615        .zip(right_stats.column_statistics.iter())
616    {
617        // If there is no overlap in any of the join columns, this means the join
618        // itself is disjoint and the cardinality is 0. Though we can only assume
619        // this when the statistics are exact (since it is a very strong assumption).
620        let left_min_val = left_stat.min_value.get_value();
621        let right_max_val = right_stat.max_value.get_value();
622        if left_min_val.is_some()
623            && right_max_val.is_some()
624            && left_min_val > right_max_val
625        {
626            return Some(
627                if left_stat.min_value.is_exact().unwrap_or(false)
628                    && right_stat.max_value.is_exact().unwrap_or(false)
629                {
630                    Precision::Exact(0)
631                } else {
632                    Precision::Inexact(0)
633                },
634            );
635        }
636
637        let left_max_val = left_stat.max_value.get_value();
638        let right_min_val = right_stat.min_value.get_value();
639        if left_max_val.is_some()
640            && right_min_val.is_some()
641            && left_max_val < right_min_val
642        {
643            return Some(
644                if left_stat.max_value.is_exact().unwrap_or(false)
645                    && right_stat.min_value.is_exact().unwrap_or(false)
646                {
647                    Precision::Exact(0)
648                } else {
649                    Precision::Inexact(0)
650                },
651            );
652        }
653    }
654
655    None
656}
657
/// Estimate the maximum number of distinct values that can be present in the
/// given column from its statistics. If `distinct_count` is available, it is used
/// directly. Otherwise, the non-null row count is used, capped by the number of
/// possible values in the `[min, max]` range when those statistics are available.
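///
/// A worked example (illustrative numbers): with `num_rows = Exact(1_000)`,
/// `null_count = Exact(100)`, no `distinct_count`, and an integer column with
/// `min = 1` and `max = 50`, the non-null count is `900` but the range only
/// admits `50` possible values, so the result is `Exact(50)`.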
663fn max_distinct_count(
664    num_rows: &Precision<usize>,
665    stats: &ColumnStatistics,
666) -> Precision<usize> {
667    match &stats.distinct_count {
668        &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc,
669        _ => {
670            // The number can never be greater than the number of rows we have
671            // minus the nulls (since they don't count as distinct values).
672            let result = match num_rows {
673                Precision::Absent => Precision::Absent,
674                Precision::Inexact(count) => {
675                    // To safeguard against inexact number of rows (e.g. 0) being smaller than
676                    // an exact null count we need to do a checked subtraction.
677                    match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
678                        None => Precision::Inexact(0),
679                        Some(non_null_count) => Precision::Inexact(non_null_count),
680                    }
681                }
682                Precision::Exact(count) => {
683                    let count = count - stats.null_count.get_value().unwrap_or(&0);
684                    if stats.null_count.is_exact().unwrap_or(false) {
685                        Precision::Exact(count)
686                    } else {
687                        Precision::Inexact(count)
688                    }
689                }
690            };
691            // Cap the estimate using the number of possible values:
692            if let (Some(min), Some(max)) =
693                (stats.min_value.get_value(), stats.max_value.get_value())
694            {
695                if let Some(range_dc) = Interval::try_new(min.clone(), max.clone())
696                    .ok()
697                    .and_then(|e| e.cardinality())
698                {
699                    let range_dc = range_dc as usize;
                    // The `unwrap` calls below are safe: short-circuiting guarantees
                    // `result` is not `Absent`, and min/max are known to be present here.
701                    return if matches!(result, Precision::Absent)
702                        || &range_dc < result.get_value().unwrap()
703                    {
704                        if stats.min_value.is_exact().unwrap()
705                            && stats.max_value.is_exact().unwrap()
706                        {
707                            Precision::Exact(range_dc)
708                        } else {
709                            Precision::Inexact(range_dc)
710                        }
711                    } else {
712                        result
713                    };
714                }
715            }
716
717            result
718        }
719    }
720}
721
722enum OnceFutState<T> {
723    Pending(OnceFutPending<T>),
724    Ready(SharedResult<Arc<T>>),
725}
726
727impl<T> Clone for OnceFutState<T> {
728    fn clone(&self) -> Self {
729        match self {
730            Self::Pending(p) => Self::Pending(p.clone()),
731            Self::Ready(r) => Self::Ready(r.clone()),
732        }
733    }
734}
735
736impl<T: 'static> OnceFut<T> {
737    /// Create a new [`OnceFut`] from a [`Future`]
738    pub(crate) fn new<Fut>(fut: Fut) -> Self
739    where
740        Fut: Future<Output = Result<T>> + Send + 'static,
741    {
742        Self {
743            state: OnceFutState::Pending(
744                fut.map(|res| res.map(Arc::new).map_err(Arc::new))
745                    .boxed()
746                    .shared(),
747            ),
748        }
749    }
750
751    /// Get the result of the computation if it is ready, without consuming it
752    pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
753        if let OnceFutState::Pending(fut) = &mut self.state {
754            let r = ready!(fut.poll_unpin(cx));
755            self.state = OnceFutState::Ready(r);
756        }
757
758        // Cannot use loop as this would trip up the borrow checker
759        match &self.state {
760            OnceFutState::Pending(_) => unreachable!(),
761            OnceFutState::Ready(r) => Poll::Ready(
762                r.as_ref()
763                    .map(|r| r.as_ref())
764                    .map_err(DataFusionError::from),
765            ),
766        }
767    }
768
769    /// Get shared reference to the result of the computation if it is ready, without consuming it
770    pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
771        if let OnceFutState::Pending(fut) = &mut self.state {
772            let r = ready!(fut.poll_unpin(cx));
773            self.state = OnceFutState::Ready(r);
774        }
775
776        match &self.state {
777            OnceFutState::Pending(_) => unreachable!(),
778            OnceFutState::Ready(r) => {
779                Poll::Ready(r.clone().map_err(DataFusionError::Shared))
780            }
781        }
782    }
783}
784
/// Returns whether a bitmap is needed to track the 'joined' status of each row
/// in the incoming right batches.
///
/// For example, in right joins we have to use a bitmap to track matched
/// right-side rows, and later enter an `EmitRightUnmatched` stage to emit
/// the unmatched right rows.
791pub(crate) fn need_produce_right_in_final(join_type: JoinType) -> bool {
792    matches!(
793        join_type,
794        JoinType::Full
795            | JoinType::Right
796            | JoinType::RightAnti
797            | JoinType::RightMark
798            | JoinType::RightSemi
799    )
800}
801
/// Some join types need to maintain a bitmap of matched indices for the left side
/// and use it to generate part of the join result.
///
/// For example, a `Left` join can emit matched rows while iterating over the right
/// side, but must keep the bitmap of matched left indices to emit the unmatched
/// left rows at the end.
807pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
808    matches!(
809        join_type,
810        JoinType::Left
811            | JoinType::LeftAnti
812            | JoinType::LeftSemi
813            | JoinType::LeftMark
814            | JoinType::Full
815    )
816}
817
818pub(crate) fn get_final_indices_from_shared_bitmap(
819    shared_bitmap: &SharedBitmapBuilder,
820    join_type: JoinType,
821    piecewise: bool,
822) -> (UInt64Array, UInt32Array) {
823    let bitmap = shared_bitmap.lock();
824    get_final_indices_from_bit_map(&bitmap, join_type, piecewise)
825}
826
/// At the end of join execution, the bitmap of matched indices is used
/// to generate the final left and right indices.
829///
830/// For example:
831///
832/// 1. left_bit_map: `[true, false, true, true, false]`
833/// 2. join_type: `Left`
834///
835/// The result is: `([1,4], [null, null])`
836pub(crate) fn get_final_indices_from_bit_map(
837    left_bit_map: &BooleanBufferBuilder,
838    join_type: JoinType,
    // Flag indicating the call originates from `PiecewiseMergeJoin`, where the
    // bitmap may track either side, so both left and right `JoinType`s are handled.
841    piecewise: bool,
842) -> (UInt64Array, UInt32Array) {
843    let left_size = left_bit_map.len();
844    if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise)
845    {
846        let left_indices = (0..left_size as u64).collect::<UInt64Array>();
847        let right_indices = (0..left_size)
848            .map(|idx| left_bit_map.get_bit(idx).then_some(0))
849            .collect::<UInt32Array>();
850        return (left_indices, right_indices);
851    }
852    let left_indices = if join_type == JoinType::LeftSemi
853        || (join_type == JoinType::RightSemi && piecewise)
854    {
855        (0..left_size)
856            .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
857            .collect::<UInt64Array>()
858    } else {
        // Only `Left`, `LeftAnti` and `Full` joins reach this branch;
        // they emit the unmatched left rows at the end.
861        (0..left_size)
862            .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
863            .collect::<UInt64Array>()
864    };
    // right_indices: all elements on the right side are null
867    let mut builder = UInt32Builder::with_capacity(left_indices.len());
868    builder.append_nulls(left_indices.len());
869    let right_indices = builder.finish();
870    (left_indices, right_indices)
871}
872
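/// Applies `filter` to the joined rows identified by `build_indices` /
/// `probe_indices` and returns the index pairs that pass the filter.
///
/// When `max_intermediate_size` is `Some(n)`, the intermediate filter batches are
/// built in chunks of at most `n` rows (limiting the size of each intermediate
/// batch); otherwise a single intermediate batch is built for all indices.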
873pub(crate) fn apply_join_filter_to_indices(
874    build_input_buffer: &RecordBatch,
875    probe_batch: &RecordBatch,
876    build_indices: UInt64Array,
877    probe_indices: UInt32Array,
878    filter: &JoinFilter,
879    build_side: JoinSide,
880    max_intermediate_size: Option<usize>,
881) -> Result<(UInt64Array, UInt32Array)> {
882    if build_indices.is_empty() && probe_indices.is_empty() {
883        return Ok((build_indices, probe_indices));
884    };
885
886    let filter_result = if let Some(max_size) = max_intermediate_size {
887        let mut filter_results =
888            Vec::with_capacity(build_indices.len().div_ceil(max_size));
889
890        for i in (0..build_indices.len()).step_by(max_size) {
891            let end = min(build_indices.len(), i + max_size);
892            let len = end - i;
893            let intermediate_batch = build_batch_from_indices(
894                filter.schema(),
895                build_input_buffer,
896                probe_batch,
897                &build_indices.slice(i, len),
898                &probe_indices.slice(i, len),
899                filter.column_indices(),
900                build_side,
901            )?;
902            let filter_result = filter
903                .expression()
904                .evaluate(&intermediate_batch)?
905                .into_array(intermediate_batch.num_rows())?;
906            filter_results.push(filter_result);
907        }
908
909        let filter_refs: Vec<&dyn Array> =
910            filter_results.iter().map(|a| a.as_ref()).collect();
911
912        compute::concat(&filter_refs)?
913    } else {
914        let intermediate_batch = build_batch_from_indices(
915            filter.schema(),
916            build_input_buffer,
917            probe_batch,
918            &build_indices,
919            &probe_indices,
920            filter.column_indices(),
921            build_side,
922        )?;
923
924        filter
925            .expression()
926            .evaluate(&intermediate_batch)?
927            .into_array(intermediate_batch.num_rows())?
928    };
929
930    let mask = as_boolean_array(&filter_result)?;
931
932    let left_filtered = compute::filter(&build_indices, mask)?;
933    let right_filtered = compute::filter(&probe_indices, mask)?;
934    Ok((
935        downcast_array(left_filtered.as_ref()),
936        downcast_array(right_filtered.as_ref()),
937    ))
938}
939
/// Returns a new [RecordBatch] by combining the build and probe batches according
/// to the given build/probe indices. The resulting batch has [Schema] `schema`.
942pub(crate) fn build_batch_from_indices(
943    schema: &Schema,
944    build_input_buffer: &RecordBatch,
945    probe_batch: &RecordBatch,
946    build_indices: &UInt64Array,
947    probe_indices: &UInt32Array,
948    column_indices: &[ColumnIndex],
949    build_side: JoinSide,
950) -> Result<RecordBatch> {
951    if schema.fields().is_empty() {
952        let options = RecordBatchOptions::new()
953            .with_match_field_names(true)
954            .with_row_count(Some(build_indices.len()));
955
956        return Ok(RecordBatch::try_new_with_options(
957            Arc::new(schema.clone()),
958            vec![],
959            &options,
960        )?);
961    }
962
963    // build the columns of the new [RecordBatch]:
964    // 1. pick whether the column is from the left or right
965    // 2. based on the pick, `take` items from the different RecordBatches
966    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
967
968    for column_index in column_indices {
969        let array = if column_index.side == JoinSide::None {
            // For mark joins, the mark column is true if the probe index is not null, false otherwise
971            Arc::new(compute::is_not_null(probe_indices)?)
972        } else if column_index.side == build_side {
973            let array = build_input_buffer.column(column_index.index);
974            if array.is_empty() || build_indices.null_count() == build_indices.len() {
975                // Outer join would generate a null index when finding no match at our side.
976                // Therefore, it's possible we are empty but need to populate an n-length null array,
977                // where n is the length of the index array.
978                assert_eq!(build_indices.null_count(), build_indices.len());
979                new_null_array(array.data_type(), build_indices.len())
980            } else {
981                take(array.as_ref(), build_indices, None)?
982            }
983        } else {
984            let array = probe_batch.column(column_index.index);
985            if array.is_empty() || probe_indices.null_count() == probe_indices.len() {
986                assert_eq!(probe_indices.null_count(), probe_indices.len());
987                new_null_array(array.data_type(), probe_indices.len())
988            } else {
989                take(array.as_ref(), probe_indices, None)?
990            }
991        };
992
993        columns.push(array);
994    }
995    Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
996}
997
/// Returns a new [RecordBatch] resulting from a join where the build/left side is empty.
999/// The resulting batch has [Schema] `schema`.
1000pub(crate) fn build_batch_empty_build_side(
1001    schema: &Schema,
1002    build_batch: &RecordBatch,
1003    probe_batch: &RecordBatch,
1004    column_indices: &[ColumnIndex],
1005    join_type: JoinType,
1006) -> Result<RecordBatch> {
1007    match join_type {
1008        // these join types only return data if the left side is not empty, so we return an
1009        // empty RecordBatch
1010        JoinType::Inner
1011        | JoinType::Left
1012        | JoinType::LeftSemi
1013        | JoinType::RightSemi
1014        | JoinType::LeftAnti
1015        | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))),
1016
1017        // the remaining joins will return data for the right columns and null for the left ones
1018        JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => {
1019            let num_rows = probe_batch.num_rows();
1020            let mut columns: Vec<Arc<dyn Array>> =
1021                Vec::with_capacity(schema.fields().len());
1022
1023            for column_index in column_indices {
1024                let array = match column_index.side {
1025                    // left -> null array
1026                    JoinSide::Left => new_null_array(
1027                        build_batch.column(column_index.index).data_type(),
1028                        num_rows,
1029                    ),
1030                    // right -> respective right array
1031                    JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)),
1032                    // right mark -> unset boolean array as there are no matches on the left side
1033                    JoinSide::None => Arc::new(BooleanArray::new(
1034                        BooleanBuffer::new_unset(num_rows),
1035                        None,
1036                    )),
1037                };
1038
1039                columns.push(array);
1040            }
1041
1042            Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
1043        }
1044    }
1045}
1046
/// Takes the matched left and right indices and adjusts them
/// according to the join type.
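///
/// For example (illustrative): for a `Right` join with matched indices
/// `left = [0]`, `right = [1]` and `adjust_range = 0..3`, the unmatched right
/// indices `[0, 2]` are appended, yielding `left = [0, null, null]` and
/// `right = [1, 0, 2]` (when right-side order is not preserved).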
1049pub(crate) fn adjust_indices_by_join_type(
1050    left_indices: UInt64Array,
1051    right_indices: UInt32Array,
1052    adjust_range: Range<usize>,
1053    join_type: JoinType,
1054    preserve_order_for_right: bool,
1055) -> Result<(UInt64Array, UInt32Array)> {
1056    match join_type {
1057        JoinType::Inner => {
1058            // matched
1059            Ok((left_indices, right_indices))
1060        }
1061        JoinType::Left => {
1062            // matched
1063            Ok((left_indices, right_indices))
            // unmatched left rows are produced at the end of the loop; they are tracked in the left visited bitmap
1065        }
1066        JoinType::Right => {
1067            // combine the matched and unmatched right result together
1068            append_right_indices(
1069                left_indices,
1070                right_indices,
1071                adjust_range,
1072                preserve_order_for_right,
1073            )
1074        }
1075        JoinType::Full => {
1076            append_right_indices(left_indices, right_indices, adjust_range, false)
1077        }
1078        JoinType::RightSemi => {
            // need to remove duplicated records on the right side
1080            let right_indices = get_semi_indices(adjust_range, &right_indices);
1081            // the left_indices will not be used later for the `right semi` join
1082            Ok((left_indices, right_indices))
1083        }
1084        JoinType::RightAnti => {
            // need to remove duplicated records on the right side
            // and get the anti index for the right side
1087            let right_indices = get_anti_indices(adjust_range, &right_indices);
1088            // the left_indices will not be used later for the `right anti` join
1089            Ok((left_indices, right_indices))
1090        }
1091        JoinType::RightMark => {
1092            let right_indices = get_mark_indices(&adjust_range, &right_indices);
1093            let left_indices_vec: Vec<u64> = adjust_range.map(|i| i as u64).collect();
1094            let left_indices = UInt64Array::from(left_indices_vec);
1095            Ok((left_indices, right_indices))
1096        }
1097        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
            // Matched or unmatched left rows are produced at the end of the loop
            // via the visited bitmap, so nothing is emitted while visiting the right batch.
1100            Ok((
1101                UInt64Array::from_iter_values(vec![]),
1102                UInt32Array::from_iter_values(vec![]),
1103            ))
1104        }
1105    }
1106}
1107
1108/// Appends right indices to left indices based on the specified order mode.
1109///
1110/// The function operates in two modes:
1111/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices
1112///    are inserted in order using the `append_probe_indices_in_order()` method.
1113/// 2. Otherwise, unmatched probe indices are simply appended after matched ones.
1114///
1115/// # Parameters
1116/// - `left_indices`: UInt64Array of left indices.
1117/// - `right_indices`: UInt32Array of right indices.
1118/// - `adjust_range`: Range to adjust the right indices.
1119/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation.
1120///
1121/// # Returns
1122/// A tuple of updated `UInt64Array` and `UInt32Array`.
1123pub(crate) fn append_right_indices(
1124    left_indices: UInt64Array,
1125    right_indices: UInt32Array,
1126    adjust_range: Range<usize>,
1127    preserve_order_for_right: bool,
1128) -> Result<(UInt64Array, UInt32Array)> {
1129    if preserve_order_for_right {
1130        Ok(append_probe_indices_in_order(
1131            left_indices,
1132            right_indices,
1133            adjust_range,
1134        ))
1135    } else {
1136        let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices);
1137
1138        if right_unmatched_indices.is_empty() {
1139            Ok((left_indices, right_indices))
1140        } else {
1141            // `into_builder()` can fail here when there is nothing to be filtered and
1142            // left_indices or right_indices has the same reference to the cached indices.
1143            // In that case, we use a slower alternative.
1144
1145            // the new left indices: left_indices + null array
1146            let mut new_left_indices_builder =
1147                left_indices.into_builder().unwrap_or_else(|left_indices| {
1148                    let mut builder = UInt64Builder::with_capacity(
1149                        left_indices.len() + right_unmatched_indices.len(),
1150                    );
1151                    debug_assert_eq!(
1152                        left_indices.null_count(),
1153                        0,
1154                        "expected left indices to have no nulls"
1155                    );
1156                    builder.append_slice(left_indices.values());
1157                    builder
1158                });
1159            new_left_indices_builder.append_nulls(right_unmatched_indices.len());
1160            let new_left_indices = UInt64Array::from(new_left_indices_builder.finish());
1161
1162            // the new right indices: right_indices + right_unmatched_indices
1163            let mut new_right_indices_builder = right_indices
1164                .into_builder()
1165                .unwrap_or_else(|right_indices| {
1166                    let mut builder = UInt32Builder::with_capacity(
1167                        right_indices.len() + right_unmatched_indices.len(),
1168                    );
1169                    debug_assert_eq!(
1170                        right_indices.null_count(),
1171                        0,
1172                        "expected right indices to have no nulls"
1173                    );
1174                    builder.append_slice(right_indices.values());
1175                    builder
1176                });
1177            debug_assert_eq!(
1178                right_unmatched_indices.null_count(),
1179                0,
1180                "expected right unmatched indices to have no nulls"
1181            );
1182            new_right_indices_builder.append_slice(right_unmatched_indices.values());
1183            let new_right_indices = UInt32Array::from(new_right_indices_builder.finish());
1184
1185            Ok((new_left_indices, new_right_indices))
1186        }
1187    }
1188}
1189
1190/// Returns `range` indices which are not present in `input_indices`
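///
/// For example (illustrative): `range = 0..5` with `input_indices = [1, 3]`
/// returns `[0, 2, 4]`.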
1191pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
1192    range: Range<usize>,
1193    input_indices: &PrimitiveArray<T>,
1194) -> PrimitiveArray<T>
1195where
1196    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1197{
1198    let bitmap = build_range_bitmap(&range, input_indices);
1199    let offset = range.start;
1200
1201    // get the anti index
1202    (range)
1203        .filter_map(|idx| {
1204            (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1205        })
1206        .collect()
1207}
1208
1209/// Returns intersection of `range` and `input_indices` omitting duplicates
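///
/// For example (illustrative): `range = 0..5` with `input_indices = [1, 3, 3]`
/// returns `[1, 3]`.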
1210pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
1211    range: Range<usize>,
1212    input_indices: &PrimitiveArray<T>,
1213) -> PrimitiveArray<T>
1214where
1215    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1216{
1217    let bitmap = build_range_bitmap(&range, input_indices);
1218    let offset = range.start;
1219    // get the semi index
1220    (range)
1221        .filter_map(|idx| {
1222            (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1223        })
1224        .collect()
1225}
1226
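/// Builds the probe-side indices for mark joins: a zero-filled `UInt32` array of
/// length `range.len()` whose validity (null) buffer marks, for each position in
/// `range`, whether it appears in `input_indices` (valid = matched, null = unmatched).
///
/// For example (illustrative): `range = 0..4` with `input_indices = [1, 3]`
/// yields `[null, 0, null, 0]`.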
1227pub(crate) fn get_mark_indices<T: ArrowPrimitiveType>(
1228    range: &Range<usize>,
1229    input_indices: &PrimitiveArray<T>,
1230) -> PrimitiveArray<UInt32Type>
1231where
1232    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1233{
1234    let mut bitmap = build_range_bitmap(range, input_indices);
1235    PrimitiveArray::new(
1236        vec![0; range.len()].into(),
1237        Some(NullBuffer::new(bitmap.finish())),
1238    )
1239}
1240
1241fn build_range_bitmap<T: ArrowPrimitiveType>(
1242    range: &Range<usize>,
1243    input: &PrimitiveArray<T>,
1244) -> BooleanBufferBuilder {
1245    let mut builder = BooleanBufferBuilder::new(range.len());
1246    builder.append_n(range.len(), false);
1247
1248    input.iter().flatten().for_each(|v| {
1249        let idx = v.as_usize();
1250        if range.contains(&idx) {
1251            builder.set_bit(idx - range.start, true);
1252        }
1253    });
1254
1255    builder
1256}
1257
1258/// Appends probe indices in order by considering the given build indices.
1259///
1260/// This function constructs new build and probe indices by iterating through
1261/// the provided indices, and appends any missing values between previous and
1262/// current probe index with a corresponding null build index.
1263///
1264/// # Parameters
1265///
1266/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices.
1267/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices.
1268/// - `range`: The range of indices to consider.
1269///
1270/// # Returns
1271///
1272/// A tuple of two arrays:
1273/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices.
1274/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices.
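///
/// # Example
///
/// Illustrative values: `build_indices = [10, 11]`, `probe_indices = [1, 3]` and
/// `range = 0..5` yield `build = [null, 10, null, 11, null]` and
/// `probe = [0, 1, 2, 3, 4]`.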
1275fn append_probe_indices_in_order(
1276    build_indices: PrimitiveArray<UInt64Type>,
1277    probe_indices: PrimitiveArray<UInt32Type>,
1278    range: Range<usize>,
1279) -> (PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>) {
1280    // Builders for new indices:
1281    let mut new_build_indices = UInt64Builder::new();
1282    let mut new_probe_indices = UInt32Builder::new();
1283    // Set previous index as the start index for the initial loop:
1284    let mut prev_index = range.start as u32;
1285    // Zip the two iterators.
1286    debug_assert!(build_indices.len() == probe_indices.len());
1287    for (build_index, probe_index) in build_indices
1288        .values()
1289        .into_iter()
1290        .zip(probe_indices.values().into_iter())
1291    {
1292        // Append values between previous and current probe index with null build index:
1293        for value in prev_index..*probe_index {
1294            new_probe_indices.append_value(value);
1295            new_build_indices.append_null();
1296        }
1297        // Append current indices:
1298        new_probe_indices.append_value(*probe_index);
1299        new_build_indices.append_value(*build_index);
1300        // Set current probe index as previous for the next iteration:
1301        prev_index = probe_index + 1;
1302    }
1303    // Append remaining probe indices after the last valid probe index with null build index.
1304    for value in prev_index..range.end as u32 {
1305        new_probe_indices.append_value(value);
1306        new_build_indices.append_null();
1307    }
1308    // Build arrays and return:
1309    (new_build_indices.finish(), new_probe_indices.finish())
1310}
1311
1312/// Metrics for build & probe joins
1313#[derive(Clone, Debug)]
1314pub(crate) struct BuildProbeJoinMetrics {
1315    pub(crate) baseline: BaselineMetrics,
1316    /// Total time for collecting build-side of join
1317    pub(crate) build_time: metrics::Time,
1318    /// Number of batches consumed by build-side
1319    pub(crate) build_input_batches: metrics::Count,
1320    /// Number of rows consumed by build-side
1321    pub(crate) build_input_rows: metrics::Count,
1322    /// Memory used by build-side in bytes
1323    pub(crate) build_mem_used: metrics::Gauge,
1324    /// Total time for joining probe-side batches to the build-side batches
1325    pub(crate) join_time: metrics::Time,
1326    /// Number of batches consumed by probe-side of this operator
1327    pub(crate) input_batches: metrics::Count,
    /// Number of rows consumed by the probe-side of this operator
1329    pub(crate) input_rows: metrics::Count,
1330    /// Number of batches produced by this operator
1331    pub(crate) output_batches: metrics::Count,
1332}
1333
1334// This Drop implementation updates the elapsed compute part of the metrics.
1335//
1336// Why is this in a Drop?
1337// - We keep track of build_time and join_time separately, but baseline metrics have
1338// a total elapsed_compute time. Instead of remembering to update both the metrics
1339// at the same time, we chose to update elapsed_compute once at the end - summing up
1340// both the parts.
1341//
1342// How does this work?
1343// - The elapsed_compute `Time` is represented by an `Arc<AtomicUsize>`. So even when
1344// this `BuildProbeJoinMetrics` is dropped, the elapsed_compute is usable through the
1345// Arc reference.
1346impl Drop for BuildProbeJoinMetrics {
1347    fn drop(&mut self) {
1348        self.baseline.elapsed_compute().add(&self.build_time);
1349        self.baseline.elapsed_compute().add(&self.join_time);
1350    }
1351}
1352
1353impl BuildProbeJoinMetrics {
1354    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
1355        let baseline = BaselineMetrics::new(metrics, partition);
1356
1357        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
1358
1359        let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition);
1360
1361        let build_input_batches =
1362            MetricBuilder::new(metrics).counter("build_input_batches", partition);
1363
1364        let build_input_rows =
1365            MetricBuilder::new(metrics).counter("build_input_rows", partition);
1366
1367        let build_mem_used =
1368            MetricBuilder::new(metrics).gauge("build_mem_used", partition);
1369
1370        let input_batches =
1371            MetricBuilder::new(metrics).counter("input_batches", partition);
1372
1373        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
1374
1375        let output_batches =
1376            MetricBuilder::new(metrics).counter("output_batches", partition);
1377
1378        Self {
1379            build_time,
1380            build_input_batches,
1381            build_input_rows,
1382            build_mem_used,
1383            join_time,
1384            input_batches,
1385            input_rows,
1386            output_batches,
1387            baseline,
1388        }
1389    }
1390}
1391
1392/// The `handle_state` macro is designed to process the result of a state-changing
1393/// operation. It operates on a `StatefulStreamResult` by matching its variants and
1394/// executing corresponding actions. This macro is used to streamline code that deals
1395/// with state transitions, reducing boilerplate and improving readability.
1396///
1397/// # Cases
1398///
1399/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the
1400///   stream join operation should proceed to the next step.
1401/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the
1402///   result, either yielding a value or indicating the stream is awaiting more
1403///   data.
1404/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
1405///   during the stream join operation.
1406///
1407/// # Arguments
1408///
1409/// * `$match_case`: An expression that evaluates to a `Result<StatefulStreamResult<_>>`.
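///
/// # Example
///
/// A minimal sketch of a `poll`-style loop that drives join state with this
/// macro; the state enum and the `process_probe_batch` method are hypothetical:
///
/// ```ignore
/// loop {
///     return match self.state {
///         JoinState::ProcessProbeBatch => handle_state!(self.process_probe_batch()),
///         JoinState::Done => Poll::Ready(None),
///     };
/// }
/// ```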
1410#[macro_export]
1411macro_rules! handle_state {
1412    ($match_case:expr) => {
1413        match $match_case {
1414            Ok(StatefulStreamResult::Continue) => continue,
1415            Ok(StatefulStreamResult::Ready(result)) => {
1416                Poll::Ready(Ok(result).transpose())
1417            }
1418            Err(e) => Poll::Ready(Some(Err(e))),
1419        }
1420    };
1421}
1422
1423/// Represents the result of a stateful operation.
1424///
1425/// This enumeration indicates whether the state produced a result that is
1426/// ready for use (`Ready`) or if the operation requires continuation (`Continue`).
1427///
1428/// Variants:
1429/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`.
1430/// - `Continue`: Indicates that the operation is not yet complete and requires further
1431///   processing or more data. When this variant is returned, it typically means that the
1432///   current invocation of the state did not produce a final result, and the operation
1433///   should be invoked again later with more data and possibly with a different state.
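///
/// A hypothetical state handler might produce it like this (sketch only; the
/// helper methods are placeholders):
///
/// ```ignore
/// fn process_batch(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
///     if self.needs_more_input() {
///         Ok(StatefulStreamResult::Continue)
///     } else {
///         Ok(StatefulStreamResult::Ready(Some(self.finish_output()?)))
///     }
/// }
/// ```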
1434pub enum StatefulStreamResult<T> {
1435    Ready(T),
1436    Continue,
1437}
1438
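/// Output [`Partitioning`] used by join operators whose output can follow either
/// input's partitioning, depending on the join type: `Left*` variants keep the left
/// partitioning, `RightSemi`/`RightAnti`/`RightMark` keep the right partitioning,
/// `Inner`/`Right` keep the right partitioning with column indices shifted by the
/// left schema width, and `Full` falls back to `UnknownPartitioning`.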
1439pub(crate) fn symmetric_join_output_partitioning(
1440    left: &Arc<dyn ExecutionPlan>,
1441    right: &Arc<dyn ExecutionPlan>,
1442    join_type: &JoinType,
1443) -> Result<Partitioning> {
1444    let left_columns_len = left.schema().fields.len();
1445    let left_partitioning = left.output_partitioning();
1446    let right_partitioning = right.output_partitioning();
1447    let result = match join_type {
1448        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
1449            left_partitioning.clone()
1450        }
1451        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1452            right_partitioning.clone()
1453        }
1454        JoinType::Inner | JoinType::Right => {
1455            adjust_right_output_partitioning(right_partitioning, left_columns_len)?
1456        }
1457        JoinType::Full => {
1458            // We could also use left partition count as they are necessarily equal.
1459            Partitioning::UnknownPartitioning(right_partitioning.partition_count())
1460        }
1461    };
1462    Ok(result)
1463}
1464
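/// Output [`Partitioning`] used by join operators that can only follow the right
/// input's partitioning: `Inner`/`Right` shift the right partitioning's column
/// indices by the left schema width, `RightSemi`/`RightAnti`/`RightMark` keep it
/// unchanged, and all remaining variants fall back to `UnknownPartitioning` with
/// the right side's partition count.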
1465pub(crate) fn asymmetric_join_output_partitioning(
1466    left: &Arc<dyn ExecutionPlan>,
1467    right: &Arc<dyn ExecutionPlan>,
1468    join_type: &JoinType,
1469) -> Result<Partitioning> {
1470    let result = match join_type {
1471        JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
1472            right.output_partitioning(),
1473            left.schema().fields().len(),
1474        )?,
1475        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1476            right.output_partitioning().clone()
1477        }
1478        JoinType::Left
1479        | JoinType::LeftSemi
1480        | JoinType::LeftAnti
1481        | JoinType::Full
1482        | JoinType::LeftMark => Partitioning::UnknownPartitioning(
1483            right.output_partitioning().partition_count(),
1484        ),
1485    };
1486    Ok(result)
1487}
1488
1489/// Trait for incrementally generating Join output.
1490///
1491/// This trait is used to limit the output of some joins
1492/// so that they do not produce single, excessively large batches.
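///
/// A minimal usage sketch (the `batch_size` and `large_batch` bindings are
/// placeholders):
///
/// ```ignore
/// let mut splitter = BatchSplitter::new(batch_size);
/// splitter.set_batch(large_batch);
/// while let Some((chunk, is_last)) = splitter.next() {
///     // Each `chunk` has at most `batch_size` rows; `is_last` marks the final slice.
/// }
/// ```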
1493pub(crate) trait BatchTransformer: Debug + Clone {
1494    /// Sets the next `RecordBatch` to be processed.
1495    fn set_batch(&mut self, batch: RecordBatch);
1496
1497    /// Retrieves the next `RecordBatch` from the transformer.
1498    /// Returns `None` if all batches have been produced.
1499    /// The boolean flag indicates whether the batch is the last one.
1500    fn next(&mut self) -> Option<(RecordBatch, bool)>;
1501}
1502
1503#[derive(Debug, Clone)]
1504/// A batch transformer that does nothing.
1505pub(crate) struct NoopBatchTransformer {
1506    /// RecordBatch to be processed
1507    batch: Option<RecordBatch>,
1508}
1509
1510impl NoopBatchTransformer {
1511    pub fn new() -> Self {
1512        Self { batch: None }
1513    }
1514}
1515
1516impl BatchTransformer for NoopBatchTransformer {
1517    fn set_batch(&mut self, batch: RecordBatch) {
1518        self.batch = Some(batch);
1519    }
1520
1521    fn next(&mut self) -> Option<(RecordBatch, bool)> {
1522        self.batch.take().map(|batch| (batch, true))
1523    }
1524}
1525
1526#[derive(Debug, Clone)]
1527/// Splits large batches into smaller batches with a maximum number of rows.
1528pub(crate) struct BatchSplitter {
1529    /// RecordBatch to be split
1530    batch: Option<RecordBatch>,
1531    /// Maximum number of rows in a split batch
1532    batch_size: usize,
1533    /// Current row index
1534    row_index: usize,
1535}
1536
1537impl BatchSplitter {
1538    /// Creates a new `BatchSplitter` with the specified batch size.
1539    pub(crate) fn new(batch_size: usize) -> Self {
1540        Self {
1541            batch: None,
1542            batch_size,
1543            row_index: 0,
1544        }
1545    }
1546}
1547
1548impl BatchTransformer for BatchSplitter {
1549    fn set_batch(&mut self, batch: RecordBatch) {
1550        self.batch = Some(batch);
1551        self.row_index = 0;
1552    }
1553
1554    fn next(&mut self) -> Option<(RecordBatch, bool)> {
1555        let Some(batch) = &self.batch else {
1556            return None;
1557        };
1558
1559        let remaining_rows = batch.num_rows() - self.row_index;
1560        let rows_to_slice = remaining_rows.min(self.batch_size);
1561        let sliced_batch = batch.slice(self.row_index, rows_to_slice);
1562        self.row_index += rows_to_slice;
1563
1564        let mut last = false;
1565        if self.row_index >= batch.num_rows() {
1566            self.batch = None;
1567            last = true;
1568        }
1569
1570        Some((sliced_batch, last))
1571    }
1572}
1573
1574/// When the order of the join inputs is changed, the output order of columns
1575/// must remain the same.
1576///
1577/// Joins output columns from their left input followed by their right input.
1578/// Thus if the inputs are reordered, the output columns must be reordered to
1579/// match the original order.
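///
/// For example (a sketch, independent of any particular operator): if the original
/// inputs were `left = [a, b]` and `right = [c]`, the swapped join emits columns in
/// the order `[c, a, b]`, and the projection added here restores `[a, b, c]`.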
1580pub fn reorder_output_after_swap(
1581    plan: Arc<dyn ExecutionPlan>,
1582    left_schema: &Schema,
1583    right_schema: &Schema,
1584) -> Result<Arc<dyn ExecutionPlan>> {
1585    let proj = ProjectionExec::try_new(
1586        swap_reverting_projection(left_schema, right_schema),
1587        plan,
1588    )?;
1589    Ok(Arc::new(proj))
1590}
1591
1592/// When the order of the join is changed, the output order of columns must
1593/// remain the same.
1594///
1595/// Returns the expressions that swap the output back so that the columns from the
1596/// original left input come first, followed by those from the original right input.
1597fn swap_reverting_projection(
1598    left_schema: &Schema,
1599    right_schema: &Schema,
1600) -> Vec<ProjectionExpr> {
1601    let right_cols =
1602        right_schema
1603            .fields()
1604            .iter()
1605            .enumerate()
1606            .map(|(i, f)| ProjectionExpr {
1607                expr: Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
1608                alias: f.name().to_owned(),
1609            });
1610    let right_len = right_cols.len();
1611    let left_cols =
1612        left_schema
1613            .fields()
1614            .iter()
1615            .enumerate()
1616            .map(|(i, f)| ProjectionExpr {
1617                expr: Arc::new(Column::new(f.name(), right_len + i))
1618                    as Arc<dyn PhysicalExpr>,
1619                alias: f.name().to_owned(),
1620            });
1621
1622    left_cols.chain(right_cols).collect()
1623}
1624
1625/// This function swaps the given join's projection.
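///
/// For example (a sketch with made-up sizes): with `left_schema_len = 2` and
/// `right_schema_len = 3`, the projection `[0, 3]` (left column 0 and right column 1)
/// becomes `[3, 1]` after the swap, because the right input's columns now come first
/// in the join output.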
1626pub fn swap_join_projection(
1627    left_schema_len: usize,
1628    right_schema_len: usize,
1629    projection: Option<&Vec<usize>>,
1630    join_type: &JoinType,
1631) -> Option<Vec<usize>> {
1632    match join_type {
1633        // For Anti/Semi join types, projection should remain unmodified,
1634        // since these joins' output schema remains the same after the swap.
1635        JoinType::LeftAnti
1636        | JoinType::LeftSemi
1637        | JoinType::RightAnti
1638        | JoinType::RightSemi
1639        | JoinType::LeftMark
1640        | JoinType::RightMark => projection.cloned(),
1641        _ => projection.map(|p| {
1642            p.iter()
1643                .map(|i| {
1644                    // If the index is less than the left schema length, it is from
1645                    // the left schema, so we add the right schema length to it.
1646                    // Otherwise, it is from the right schema, so we subtract the left
1647                    // schema length from it.
1648                    if *i < left_schema_len {
1649                        *i + right_schema_len
1650                    } else {
1651                        *i - left_schema_len
1652                    }
1653                })
1654                .collect()
1655        }),
1656    }
1657}
1658
1659/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on`
1660/// using `offset` as a start value for `batch` row indices.
1661///
1662/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating the hash map,
1663/// which keeps either the first (if set to true) or the last (if set to false) row index
1664/// as the chain head for rows with equal hash values.
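///
/// For example, with `offset = 100` and a 10-row `batch`, rows are inserted under
/// indices `100..110`; with `fifo_hashmap = true`, duplicate keys keep the smallest
/// of those indices as the chain head.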
1665#[allow(clippy::too_many_arguments)]
1666pub fn update_hash(
1667    on: &[PhysicalExprRef],
1668    batch: &RecordBatch,
1669    hash_map: &mut dyn JoinHashMapType,
1670    offset: usize,
1671    random_state: &RandomState,
1672    hashes_buffer: &mut Vec<u64>,
1673    deleted_offset: usize,
1674    fifo_hashmap: bool,
1675) -> Result<()> {
1676    // evaluate the keys
1677    let keys_values = on
1678        .iter()
1679        .map(|c| c.evaluate(batch)?.into_array(batch.num_rows()))
1680        .collect::<Result<Vec<_>>>()?;
1681
1682    // calculate the hash values
1683    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1684
1685    // For the usual JoinHashMap, this is a no-op.
1686    hash_map.extend_zero(batch.num_rows());
1687
1688    // Updating JoinHashMap from hash values iterator
1689    let hash_values_iter = hash_values
1690        .iter()
1691        .enumerate()
1692        .map(|(i, val)| (i + offset, val));
1693
1694    if fifo_hashmap {
1695        hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset);
1696    } else {
1697        hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset);
1698    }
1699
1700    Ok(())
1701}
1702
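/// Given candidate pairs of row indices into the left and right join-key arrays,
/// keeps only the pairs whose key values are actually equal (honoring
/// `null_equality`) and returns the filtered index arrays.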
1703pub(super) fn equal_rows_arr(
1704    indices_left: &UInt64Array,
1705    indices_right: &UInt32Array,
1706    left_arrays: &[ArrayRef],
1707    right_arrays: &[ArrayRef],
1708    null_equality: NullEquality,
1709) -> Result<(UInt64Array, UInt32Array)> {
1710    let mut iter = left_arrays.iter().zip(right_arrays.iter());
1711
1712    let Some((first_left, first_right)) = iter.next() else {
1713        return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
1714    };
1715
1716    let arr_left = take(first_left.as_ref(), indices_left, None)?;
1717    let arr_right = take(first_right.as_ref(), indices_right, None)?;
1718
1719    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;
1720
1721    // Use map and try_fold to iterate over the remaining pairs of arrays.
1722    // In each iteration, take is used on the pair of arrays and their equality is determined.
1723    // The results are then folded (combined) using the and function to get a final equality result.
1724    equal = iter
1725        .map(|(left, right)| {
1726            let arr_left = take(left.as_ref(), indices_left, None)?;
1727            let arr_right = take(right.as_ref(), indices_right, None)?;
1728            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
1729        })
1730        .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
1731
1732    let filter_builder = FilterBuilder::new(&equal).optimize().build();
1733
1734    let left_filtered = filter_builder.filter(indices_left)?;
1735    let right_filtered = filter_builder.filter(indices_right)?;
1736
1737    Ok((
1738        downcast_array(left_filtered.as_ref()),
1739        downcast_array(right_filtered.as_ref()),
1740    ))
1741}
1742
1743// Version of the `eq` kernel that can optionally treat nulls as equal (via `not_distinct`).
1744fn eq_dyn_null(
1745    left: &dyn Array,
1746    right: &dyn Array,
1747    null_equality: NullEquality,
1748) -> Result<BooleanArray, ArrowError> {
1749    // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
1750    // implementation
1751    // <https://github.com/apache/datafusion/issues/10749>
1752    if left.data_type().is_nested() {
1753        let op = match null_equality {
1754            NullEquality::NullEqualsNothing => Operator::Eq,
1755            NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
1756        };
1757        return Ok(compare_op_for_nested(op, &left, &right)?);
1758    }
1759    match null_equality {
1760        NullEquality::NullEqualsNothing => eq(&left, &right),
1761        NullEquality::NullEqualsNull => not_distinct(&left, &right),
1762    }
1763}
1764
1765/// Get the comparison result of two rows of join arrays, comparing column by column.
1766pub fn compare_join_arrays(
1767    left_arrays: &[ArrayRef],
1768    left: usize,
1769    right_arrays: &[ArrayRef],
1770    right: usize,
1771    sort_options: &[SortOptions],
1772    null_equality: NullEquality,
1773) -> Result<Ordering> {
1774    let mut res = Ordering::Equal;
1775    for ((left_array, right_array), sort_options) in
1776        left_arrays.iter().zip(right_arrays).zip(sort_options)
1777    {
1778        macro_rules! compare_value {
1779            ($T:ty) => {{
1780                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
1781                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
1782                match (left_array.is_null(left), right_array.is_null(right)) {
1783                    (false, false) => {
1784                        let left_value = &left_array.value(left);
1785                        let right_value = &right_array.value(right);
1786                        res = left_value.partial_cmp(right_value).unwrap();
1787                        if sort_options.descending {
1788                            res = res.reverse();
1789                        }
1790                    }
1791                    (true, false) => {
1792                        res = if sort_options.nulls_first {
1793                            Ordering::Less
1794                        } else {
1795                            Ordering::Greater
1796                        };
1797                    }
1798                    (false, true) => {
1799                        res = if sort_options.nulls_first {
1800                            Ordering::Greater
1801                        } else {
1802                            Ordering::Less
1803                        };
1804                    }
1805                    _ => {
1806                        res = match null_equality {
1807                            NullEquality::NullEqualsNothing => Ordering::Less,
1808                            NullEquality::NullEqualsNull => Ordering::Equal,
1809                        };
1810                    }
1811                }
1812            }};
1813        }
1814
1815        match left_array.data_type() {
1816            DataType::Null => {}
1817            DataType::Boolean => compare_value!(BooleanArray),
1818            DataType::Int8 => compare_value!(Int8Array),
1819            DataType::Int16 => compare_value!(Int16Array),
1820            DataType::Int32 => compare_value!(Int32Array),
1821            DataType::Int64 => compare_value!(Int64Array),
1822            DataType::UInt8 => compare_value!(UInt8Array),
1823            DataType::UInt16 => compare_value!(UInt16Array),
1824            DataType::UInt32 => compare_value!(UInt32Array),
1825            DataType::UInt64 => compare_value!(UInt64Array),
1826            DataType::Float32 => compare_value!(Float32Array),
1827            DataType::Float64 => compare_value!(Float64Array),
1828            DataType::Binary => compare_value!(BinaryArray),
1829            DataType::BinaryView => compare_value!(BinaryViewArray),
1830            DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
1831            DataType::LargeBinary => compare_value!(LargeBinaryArray),
1832            DataType::Utf8 => compare_value!(StringArray),
1833            DataType::Utf8View => compare_value!(StringViewArray),
1834            DataType::LargeUtf8 => compare_value!(LargeStringArray),
1835            DataType::Decimal128(..) => compare_value!(Decimal128Array),
1836            DataType::Timestamp(time_unit, None) => match time_unit {
1837                TimeUnit::Second => compare_value!(TimestampSecondArray),
1838                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
1839                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
1840                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
1841            },
1842            DataType::Date32 => compare_value!(Date32Array),
1843            DataType::Date64 => compare_value!(Date64Array),
1844            dt => {
1845                return not_impl_err!(
1846                    "Unsupported data type in sort merge join comparator: {}",
1847                    dt
1848                );
1849            }
1850        }
1851        if !res.is_eq() {
1852            break;
1853        }
1854    }
1855    Ok(res)
1856}
1857
1858#[cfg(test)]
1859mod tests {
1860    use std::collections::HashMap;
1861    use std::pin::Pin;
1862
1863    use super::*;
1864
1865    use arrow::array::Int32Array;
1866    use arrow::datatypes::{DataType, Fields};
1867    use arrow::error::{ArrowError, Result as ArrowResult};
1868    use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
1869    use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
1870    use datafusion_physical_expr::PhysicalSortExpr;
1871
1872    use rstest::rstest;
1873
1874    fn check(
1875        left: &[Column],
1876        right: &[Column],
1877        on: &[(PhysicalExprRef, PhysicalExprRef)],
1878    ) -> Result<()> {
1879        let left = left
1880            .iter()
1881            .map(|x| x.to_owned())
1882            .collect::<HashSet<Column>>();
1883        let right = right
1884            .iter()
1885            .map(|x| x.to_owned())
1886            .collect::<HashSet<Column>>();
1887        check_join_set_is_valid(&left, &right, on)
1888    }
1889
1890    #[test]
1891    fn check_valid() -> Result<()> {
1892        let left = vec![Column::new("a", 0), Column::new("b1", 1)];
1893        let right = vec![Column::new("a", 0), Column::new("b2", 1)];
1894        let on = &[(
1895            Arc::new(Column::new("a", 0)) as _,
1896            Arc::new(Column::new("a", 0)) as _,
1897        )];
1898
1899        check(&left, &right, on)?;
1900        Ok(())
1901    }
1902
1903    #[test]
1904    fn check_not_in_right() {
1905        let left = vec![Column::new("a", 0), Column::new("b", 1)];
1906        let right = vec![Column::new("b", 0)];
1907        let on = &[(
1908            Arc::new(Column::new("a", 0)) as _,
1909            Arc::new(Column::new("a", 0)) as _,
1910        )];
1911
1912        assert!(check(&left, &right, on).is_err());
1913    }
1914
1915    #[tokio::test]
1916    async fn check_error_nesting() {
1917        let once_fut = OnceFut::<()>::new(async {
1918            arrow_err!(ArrowError::CsvError("some error".to_string()))
1919        });
1920
1921        struct TestFut(OnceFut<()>);
1922        impl Future for TestFut {
1923            type Output = ArrowResult<()>;
1924
1925            fn poll(
1926                mut self: Pin<&mut Self>,
1927                cx: &mut Context<'_>,
1928            ) -> Poll<Self::Output> {
1929                match ready!(self.0.get(cx)) {
1930                    Ok(()) => Poll::Ready(Ok(())),
1931                    Err(e) => Poll::Ready(Err(e.into())),
1932                }
1933            }
1934        }
1935
1936        let res = TestFut(once_fut).await;
1937        let arrow_err_from_fut = res.expect_err("once_fut always returns an error");
1938
1939        let wrapped_err = DataFusionError::from(arrow_err_from_fut);
1940        let root_err = wrapped_err.find_root();
1941
1942        let _expected =
1943            arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned()));
1944
1945        assert!(matches!(root_err, _expected))
1946    }
1947
1948    #[test]
1949    fn check_not_in_left() {
1950        let left = vec![Column::new("b", 0)];
1951        let right = vec![Column::new("a", 0)];
1952        let on = &[(
1953            Arc::new(Column::new("a", 0)) as _,
1954            Arc::new(Column::new("a", 0)) as _,
1955        )];
1956
1957        assert!(check(&left, &right, on).is_err());
1958    }
1959
1960    #[test]
1961    fn check_collision() {
1962        // column "a" would appear both in left and right
1963        let left = vec![Column::new("a", 0), Column::new("c", 1)];
1964        let right = vec![Column::new("a", 0), Column::new("b", 1)];
1965        let on = &[(
1966            Arc::new(Column::new("a", 0)) as _,
1967            Arc::new(Column::new("b", 1)) as _,
1968        )];
1969
1970        assert!(check(&left, &right, on).is_ok());
1971    }
1972
1973    #[test]
1974    fn check_in_right() {
1975        let left = vec![Column::new("a", 0), Column::new("c", 1)];
1976        let right = vec![Column::new("b", 0)];
1977        let on = &[(
1978            Arc::new(Column::new("a", 0)) as _,
1979            Arc::new(Column::new("b", 0)) as _,
1980        )];
1981
1982        assert!(check(&left, &right, on).is_ok());
1983    }
1984
1985    #[test]
1986    fn test_join_schema() -> Result<()> {
1987        let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1988        let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1989        let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
1990        let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
1991
1992        let cases = vec![
1993            (&a, &b, JoinType::Inner, &a, &b),
1994            (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
1995            (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
1996            (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
1997            // right input of a `LEFT` join can be null, regardless of input nullness
1998            (&a, &b, JoinType::Left, &a, &b_nulls),
1999            (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
2000            (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
2001            (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
2002            // left input of a `RIGHT` join can be null, regardless of input nullness
2003            (&a, &b, JoinType::Right, &a_nulls, &b),
2004            (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
2005            (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
2006            (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
2007            // Either input of a `FULL` join can be null
2008            (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
2009            (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
2010            (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
2011            (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
2012        ];
2013
2014        for (left_in, right_in, join_type, left_out, right_out) in cases {
2015            let (schema, _) = build_join_schema(left_in, right_in, &join_type);
2016
2017            let expected_fields = left_out
2018                .fields()
2019                .iter()
2020                .cloned()
2021                .chain(right_out.fields().iter().cloned())
2022                .collect::<Fields>();
2023
2024            let expected_schema = Schema::new(expected_fields);
2025            assert_eq!(
2026                schema,
2027                expected_schema,
2028                "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
2029                left_in.fields()[0].name(),
2030                left_in.fields()[0].is_nullable(),
2031                right_in.fields()[0].name(),
2032                right_in.fields()[0].is_nullable(),
2033                join_type
2034            );
2035        }
2036
2037        Ok(())
2038    }
2039
2040    fn create_stats(
2041        num_rows: Option<usize>,
2042        column_stats: Vec<ColumnStatistics>,
2043        is_exact: bool,
2044    ) -> Statistics {
2045        Statistics {
2046            num_rows: if is_exact {
2047                num_rows.map(Exact)
2048            } else {
2049                num_rows.map(Inexact)
2050            }
2051            .unwrap_or(Absent),
2052            column_statistics: column_stats,
2053            total_byte_size: Absent,
2054        }
2055    }
2056
2057    fn create_column_stats(
2058        min: Precision<i64>,
2059        max: Precision<i64>,
2060        distinct_count: Precision<usize>,
2061        null_count: Precision<usize>,
2062    ) -> ColumnStatistics {
2063        ColumnStatistics {
2064            distinct_count,
2065            min_value: min.map(ScalarValue::from),
2066            max_value: max.map(ScalarValue::from),
2067            sum_value: Absent,
2068            null_count,
2069        }
2070    }
2071
2072    type PartialStats = (
2073        usize,
2074        Precision<i64>,
2075        Precision<i64>,
2076        Precision<usize>,
2077        Precision<usize>,
2078    );
2079
2080    // This is mainly for validating all the edge cases of the estimation, but
2081    // more advanced (and real-world) test cases are below, where we need some control
2082    // over the expected output (since it varies from join type to join type).
2083    #[test]
2084    fn test_inner_join_cardinality_single_column() -> Result<()> {
2085        let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
2086            // ------------------------------------------------
2087            // | left(rows, min, max, distinct, null_count),  |
2088            // | right(rows, min, max, distinct, null_count), |
2089            // | expected,                                    |
2090            // ------------------------------------------------
2091
2092            // Cardinality computation
2093            // =======================
2094            //
2095            // distinct(left) == NaN, distinct(right) == NaN
2096            (
2097                (10, Inexact(1), Inexact(10), Absent, Absent),
2098                (10, Inexact(1), Inexact(10), Absent, Absent),
2099                Some(Inexact(10)),
2100            ),
2101            // range(left) > range(right)
2102            (
2103                (10, Inexact(6), Inexact(10), Absent, Absent),
2104                (10, Inexact(8), Inexact(10), Absent, Absent),
2105                Some(Inexact(20)),
2106            ),
2107            // range(right) > range(left)
2108            (
2109                (10, Inexact(8), Inexact(10), Absent, Absent),
2110                (10, Inexact(6), Inexact(10), Absent, Absent),
2111                Some(Inexact(20)),
2112            ),
2113            // range(left) > len(left), range(right) > len(right)
2114            (
2115                (10, Inexact(1), Inexact(15), Absent, Absent),
2116                (20, Inexact(1), Inexact(40), Absent, Absent),
2117                Some(Inexact(10)),
2118            ),
2119            // Distinct count matches the range
2120            (
2121                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
2122                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
2123                Some(Inexact(10)),
2124            ),
2125            // Distinct count takes precedence over the range
2126            (
2127                (10, Inexact(1), Inexact(3), Inexact(10), Absent),
2128                (10, Inexact(1), Inexact(3), Inexact(10), Absent),
2129                Some(Inexact(10)),
2130            ),
2131            // distinct(left) > distinct(right)
2132            (
2133                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
2134                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
2135                Some(Inexact(20)),
2136            ),
2137            // distinct(right) > distinct(left)
2138            (
2139                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
2140                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
2141                Some(Inexact(20)),
2142            ),
2143            // min(left) < 0 (range(left) > range(right))
2144            (
2145                (10, Inexact(-5), Inexact(5), Absent, Absent),
2146                (10, Inexact(1), Inexact(5), Absent, Absent),
2147                Some(Inexact(10)),
2148            ),
2149            // min(right) < 0, max(right) < 0 (range(right) > range(left))
2150            (
2151                (10, Inexact(-25), Inexact(-20), Absent, Absent),
2152                (10, Inexact(-25), Inexact(-15), Absent, Absent),
2153                Some(Inexact(10)),
2154            ),
2155            // range(left) < 0, range(right) >= 0
2156            // (there isn't a case where both left and right ranges are negative
2157            //  so one of them is always going to work; this just proves negative
2158            //  ranges with bigger absolute values are not accidentally used).
2159            (
2160                (10, Inexact(-10), Inexact(0), Absent, Absent),
2161                (10, Inexact(0), Inexact(10), Inexact(5), Absent),
2162                Some(Inexact(10)),
2163            ),
2164            // range(left) = 1, range(right) = 1
2165            (
2166                (10, Inexact(1), Inexact(1), Absent, Absent),
2167                (10, Inexact(1), Inexact(1), Absent, Absent),
2168                Some(Inexact(100)),
2169            ),
2170            //
2171            // Edge cases
2172            // ==========
2173            //
2174            // No column level stats, fall back to row count.
2175            (
2176                (10, Absent, Absent, Absent, Absent),
2177                (10, Absent, Absent, Absent, Absent),
2178                Some(Inexact(10)),
2179            ),
2180            // No min or max (or both), but distinct available.
2181            (
2182                (10, Absent, Absent, Inexact(3), Absent),
2183                (10, Absent, Absent, Inexact(3), Absent),
2184                Some(Inexact(33)),
2185            ),
2186            (
2187                (10, Inexact(2), Absent, Inexact(3), Absent),
2188                (10, Absent, Inexact(5), Inexact(3), Absent),
2189                Some(Inexact(33)),
2190            ),
2191            (
2192                (10, Absent, Inexact(3), Inexact(3), Absent),
2193                (10, Inexact(1), Absent, Inexact(3), Absent),
2194                Some(Inexact(33)),
2195            ),
2196            // No min or max, fall back to row count
2197            (
2198                (10, Absent, Inexact(3), Absent, Absent),
2199                (10, Inexact(1), Absent, Absent, Absent),
2200                Some(Inexact(10)),
2201            ),
2202            // Non overlapping min/max (when exact=False).
2203            (
2204                (10, Absent, Inexact(4), Absent, Absent),
2205                (10, Inexact(5), Absent, Absent, Absent),
2206                Some(Inexact(0)),
2207            ),
2208            (
2209                (10, Inexact(0), Inexact(10), Absent, Absent),
2210                (10, Inexact(11), Inexact(20), Absent, Absent),
2211                Some(Inexact(0)),
2212            ),
2213            (
2214                (10, Inexact(11), Inexact(20), Absent, Absent),
2215                (10, Inexact(0), Inexact(10), Absent, Absent),
2216                Some(Inexact(0)),
2217            ),
2218            // distinct(left) = 0, distinct(right) = 0
2219            (
2220                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
2221                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
2222                None,
2223            ),
2224            // Inexact row count < exact null count with absent distinct count
2225            (
2226                (0, Inexact(1), Inexact(10), Absent, Exact(5)),
2227                (10, Inexact(1), Inexact(10), Absent, Absent),
2228                Some(Inexact(0)),
2229            ),
2230        ];
2231
2232        for (left_info, right_info, expected_cardinality) in cases {
2233            let left_num_rows = left_info.0;
2234            let left_col_stats = vec![create_column_stats(
2235                left_info.1,
2236                left_info.2,
2237                left_info.3,
2238                left_info.4,
2239            )];
2240
2241            let right_num_rows = right_info.0;
2242            let right_col_stats = vec![create_column_stats(
2243                right_info.1,
2244                right_info.2,
2245                right_info.3,
2246                right_info.4,
2247            )];
2248
2249            assert_eq!(
2250                estimate_inner_join_cardinality(
2251                    Statistics {
2252                        num_rows: Inexact(left_num_rows),
2253                        total_byte_size: Absent,
2254                        column_statistics: left_col_stats.clone(),
2255                    },
2256                    Statistics {
2257                        num_rows: Inexact(right_num_rows),
2258                        total_byte_size: Absent,
2259                        column_statistics: right_col_stats.clone(),
2260                    },
2261                ),
2262                expected_cardinality.clone()
2263            );
2264
2265            // We should also be able to use join_cardinality to get the same results
2266            let join_type = JoinType::Inner;
2267            let join_on = vec![(
2268                Arc::new(Column::new("a", 0)) as _,
2269                Arc::new(Column::new("b", 0)) as _,
2270            )];
2271            let partial_join_stats = estimate_join_cardinality(
2272                &join_type,
2273                create_stats(Some(left_num_rows), left_col_stats.clone(), false),
2274                create_stats(Some(right_num_rows), right_col_stats.clone(), false),
2275                &join_on,
2276            );
2277
2278            assert_eq!(
2279                partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
2280                expected_cardinality.clone()
2281            );
2282            assert_eq!(
2283                partial_join_stats.map(|s| s.column_statistics),
2284                expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
2285            );
2286        }
2287        Ok(())
2288    }
2289
2290    #[test]
2291    fn test_inner_join_cardinality_multiple_column() -> Result<()> {
2292        let left_col_stats = vec![
2293            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2294            create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
2295        ];
2296
2297        let right_col_stats = vec![
2298            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2299            create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
2300        ];
2301
2302        // We have statistics about 4 columns, where the highest distinct
2303        // count is 200, so we are going to pick it.
2304        assert_eq!(
2305            estimate_inner_join_cardinality(
2306                Statistics {
2307                    num_rows: Inexact(400),
2308                    total_byte_size: Absent,
2309                    column_statistics: left_col_stats,
2310                },
2311                Statistics {
2312                    num_rows: Inexact(400),
2313                    total_byte_size: Absent,
2314                    column_statistics: right_col_stats,
2315                },
2316            ),
2317            Some(Inexact((400 * 400) / 200))
2318        );
2319        Ok(())
2320    }
2321
2322    #[test]
2323    fn test_inner_join_cardinality_decimal_range() -> Result<()> {
2324        let left_col_stats = vec![ColumnStatistics {
2325            distinct_count: Absent,
2326            min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)),
2327            max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)),
2328            ..Default::default()
2329        }];
2330
2331        let right_col_stats = vec![ColumnStatistics {
2332            distinct_count: Absent,
2333            min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)),
2334            max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)),
2335            ..Default::default()
2336        }];
2337
2338        assert_eq!(
2339            estimate_inner_join_cardinality(
2340                Statistics {
2341                    num_rows: Inexact(100),
2342                    total_byte_size: Absent,
2343                    column_statistics: left_col_stats,
2344                },
2345                Statistics {
2346                    num_rows: Inexact(100),
2347                    total_byte_size: Absent,
2348                    column_statistics: right_col_stats,
2349                },
2350            ),
2351            Some(Inexact(100))
2352        );
2353        Ok(())
2354    }
2355
2356    #[test]
2357    fn test_join_cardinality() -> Result<()> {
2358        // Left table (rows=1000)
2359        //   a: min=0, max=100, distinct=100
2360        //   b: min=0, max=500, distinct=500
2361        //   x: min=1000, max=10000, distinct=None
2362        //
2363        // Right table (rows=2000)
2364        //   c: min=0, max=100, distinct=50
2365        //   d: min=0, max=2000, distinct=2500 (how? some inexact statistics)
2366        //   y: min=0, max=100, distinct=None
2367        //
2368        // Join on a=c, b=d (ignore x/y)
2369        let cases = vec![
2370            (JoinType::Inner, 800),
2371            (JoinType::Left, 1000),
2372            (JoinType::Right, 2000),
2373            (JoinType::Full, 2200),
2374        ];
2375
2376        let left_col_stats = vec![
2377            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2378            create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2379            create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2380        ];
2381
2382        let right_col_stats = vec![
2383            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2384            create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2385            create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2386        ];
2387
2388        for (join_type, expected_num_rows) in cases {
2389            let join_on = vec![
2390                (
2391                    Arc::new(Column::new("a", 0)) as _,
2392                    Arc::new(Column::new("c", 0)) as _,
2393                ),
2394                (
2395                    Arc::new(Column::new("b", 1)) as _,
2396                    Arc::new(Column::new("d", 1)) as _,
2397                ),
2398            ];
2399
2400            let partial_join_stats = estimate_join_cardinality(
2401                &join_type,
2402                create_stats(Some(1000), left_col_stats.clone(), false),
2403                create_stats(Some(2000), right_col_stats.clone(), false),
2404                &join_on,
2405            )
2406            .unwrap();
2407            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2408            assert_eq!(
2409                partial_join_stats.column_statistics,
2410                [left_col_stats.clone(), right_col_stats.clone()].concat()
2411            );
2412        }
2413
2414        Ok(())
2415    }
2416
2417    #[test]
2418    fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
2419        // Left table (rows=1000)
2420        //   a: min=0, max=100, distinct=100
2421        //   b: min=0, max=500, distinct=500
2422        //   x: min=1000, max=10000, distinct=None
2423        //
2424        // Right table (rows=2000)
2425        //   c: min=0, max=100, distinct=50
2426        //   d: min=0, max=2000, distinct=2500 (how? some inexact statistics)
2427        //   y: min=0, max=100, distinct=None
2428        //
2429        // Join on a=c, x=y (ignores b/d) where x and y do not intersect
2430
2431        let left_col_stats = vec![
2432            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2433            create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2434            create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2435        ];
2436
2437        let right_col_stats = vec![
2438            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2439            create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2440            create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2441        ];
2442
2443        let join_on = vec![
2444            (
2445                Arc::new(Column::new("a", 0)) as _,
2446                Arc::new(Column::new("c", 0)) as _,
2447            ),
2448            (
2449                Arc::new(Column::new("x", 2)) as _,
2450                Arc::new(Column::new("y", 2)) as _,
2451            ),
2452        ];
2453
2454        let cases = vec![
2455            // Join type, expected cardinality
2456            //
2457            // When an inner join is disjoint, that means it won't
2458            // produce any rows.
2459            (JoinType::Inner, 0),
2460            // But left/right outer joins will produce at least
2461            // the amount of rows from the left/right side.
2462            (JoinType::Left, 1000),
2463            (JoinType::Right, 2000),
2464            // And a full outer join will produce at least the combination
2465            // of the rows above (minus the cardinality of the inner join, which
2466            // is 0).
2467            (JoinType::Full, 3000),
2468        ];
2469
2470        for (join_type, expected_num_rows) in cases {
2471            let partial_join_stats = estimate_join_cardinality(
2472                &join_type,
2473                create_stats(Some(1000), left_col_stats.clone(), true),
2474                create_stats(Some(2000), right_col_stats.clone(), true),
2475                &join_on,
2476            )
2477            .unwrap();
2478            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2479            assert_eq!(
2480                partial_join_stats.column_statistics,
2481                [left_col_stats.clone(), right_col_stats.clone()].concat()
2482            );
2483        }
2484
2485        Ok(())
2486    }
2487
2488    #[test]
2489    fn test_anti_semi_join_cardinality() -> Result<()> {
2490        let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
2491            // ------------------------------------------------
2492            // | join_type,                                    |
2493            // | left(rows, min, max, distinct, null_count),  |
2494            // | right(rows, min, max, distinct, null_count), |
2495            // | expected,                                    |
2496            // ------------------------------------------------
2497
2498            // Cardinality computation
2499            // =======================
2500            (
2501                JoinType::LeftSemi,
2502                (50, Inexact(10), Inexact(20), Absent, Absent),
2503                (10, Inexact(15), Inexact(25), Absent, Absent),
2504                Some(50),
2505            ),
2506            (
2507                JoinType::RightSemi,
2508                (50, Inexact(10), Inexact(20), Absent, Absent),
2509                (10, Inexact(15), Inexact(25), Absent, Absent),
2510                Some(10),
2511            ),
2512            (
2513                JoinType::LeftSemi,
2514                (10, Absent, Absent, Absent, Absent),
2515                (50, Absent, Absent, Absent, Absent),
2516                Some(10),
2517            ),
2518            (
2519                JoinType::LeftSemi,
2520                (50, Inexact(10), Inexact(20), Absent, Absent),
2521                (10, Inexact(30), Inexact(40), Absent, Absent),
2522                Some(0),
2523            ),
2524            (
2525                JoinType::LeftSemi,
2526                (50, Inexact(10), Absent, Absent, Absent),
2527                (10, Absent, Inexact(5), Absent, Absent),
2528                Some(0),
2529            ),
2530            (
2531                JoinType::LeftSemi,
2532                (50, Absent, Inexact(20), Absent, Absent),
2533                (10, Inexact(30), Absent, Absent, Absent),
2534                Some(0),
2535            ),
2536            (
2537                JoinType::LeftAnti,
2538                (50, Inexact(10), Inexact(20), Absent, Absent),
2539                (10, Inexact(15), Inexact(25), Absent, Absent),
2540                Some(50),
2541            ),
2542            (
2543                JoinType::RightAnti,
2544                (50, Inexact(10), Inexact(20), Absent, Absent),
2545                (10, Inexact(15), Inexact(25), Absent, Absent),
2546                Some(10),
2547            ),
2548            (
2549                JoinType::LeftAnti,
2550                (10, Absent, Absent, Absent, Absent),
2551                (50, Absent, Absent, Absent, Absent),
2552                Some(10),
2553            ),
2554            (
2555                JoinType::LeftAnti,
2556                (50, Inexact(10), Inexact(20), Absent, Absent),
2557                (10, Inexact(30), Inexact(40), Absent, Absent),
2558                Some(50),
2559            ),
2560            (
2561                JoinType::LeftAnti,
2562                (50, Inexact(10), Absent, Absent, Absent),
2563                (10, Absent, Inexact(5), Absent, Absent),
2564                Some(50),
2565            ),
2566            (
2567                JoinType::LeftAnti,
2568                (50, Absent, Inexact(20), Absent, Absent),
2569                (10, Inexact(30), Absent, Absent, Absent),
2570                Some(50),
2571            ),
2572        ];
2573
2574        let join_on = vec![(
2575            Arc::new(Column::new("l_col", 0)) as _,
2576            Arc::new(Column::new("r_col", 0)) as _,
2577        )];
2578
2579        for (join_type, outer_info, inner_info, expected) in cases {
2580            let outer_num_rows = outer_info.0;
2581            let outer_col_stats = vec![create_column_stats(
2582                outer_info.1,
2583                outer_info.2,
2584                outer_info.3,
2585                outer_info.4,
2586            )];
2587
2588            let inner_num_rows = inner_info.0;
2589            let inner_col_stats = vec![create_column_stats(
2590                inner_info.1,
2591                inner_info.2,
2592                inner_info.3,
2593                inner_info.4,
2594            )];
2595
2596            let output_cardinality = estimate_join_cardinality(
2597                &join_type,
2598                Statistics {
2599                    num_rows: Inexact(outer_num_rows),
2600                    total_byte_size: Absent,
2601                    column_statistics: outer_col_stats,
2602                },
2603                Statistics {
2604                    num_rows: Inexact(inner_num_rows),
2605                    total_byte_size: Absent,
2606                    column_statistics: inner_col_stats,
2607                },
2608                &join_on,
2609            )
2610            .map(|cardinality| cardinality.num_rows);
2611
2612            assert_eq!(
2613                output_cardinality, expected,
2614                "failure for join_type: {join_type}"
2615            );
2616        }
2617
2618        Ok(())
2619    }
2620
2621    #[test]
2622    fn test_semi_join_cardinality_absent_rows() -> Result<()> {
2623        let dummy_column_stats =
2624            vec![create_column_stats(Absent, Absent, Absent, Absent)];
2625        let join_on = vec![(
2626            Arc::new(Column::new("l_col", 0)) as _,
2627            Arc::new(Column::new("r_col", 0)) as _,
2628        )];
2629
2630        let absent_outer_estimation = estimate_join_cardinality(
2631            &JoinType::LeftSemi,
2632            Statistics {
2633                num_rows: Absent,
2634                total_byte_size: Absent,
2635                column_statistics: dummy_column_stats.clone(),
2636            },
2637            Statistics {
2638                num_rows: Exact(10),
2639                total_byte_size: Absent,
2640                column_statistics: dummy_column_stats.clone(),
2641            },
2642            &join_on,
2643        );
2644        assert!(
2645            absent_outer_estimation.is_none(),
2646            "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
2647        );
2648
2649        let absent_inner_estimation = estimate_join_cardinality(
2650            &JoinType::LeftSemi,
2651            Statistics {
2652                num_rows: Inexact(500),
2653                total_byte_size: Absent,
2654                column_statistics: dummy_column_stats.clone(),
2655            },
2656            Statistics {
2657                num_rows: Absent,
2658                total_byte_size: Absent,
2659                column_statistics: dummy_column_stats.clone(),
2660            },
2661            &join_on,
2662        ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows");
2663
2664        assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows");
2665
2666        let absent_inner_estimation = estimate_join_cardinality(
2667            &JoinType::LeftSemi,
2668            Statistics {
2669                num_rows: Absent,
2670                total_byte_size: Absent,
2671                column_statistics: dummy_column_stats.clone(),
2672            },
2673            Statistics {
2674                num_rows: Absent,
2675                total_byte_size: Absent,
2676                column_statistics: dummy_column_stats,
2677            },
2678            &join_on,
2679        );
2680        assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows");
2681
2682        Ok(())
2683    }
2684
2685    #[test]
2686    fn test_calculate_join_output_ordering() -> Result<()> {
2687        let left_ordering = LexOrdering::new(vec![
2688            PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
2689            PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
2690            PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
2691        ]);
2692        let right_ordering = LexOrdering::new(vec![
2693            PhysicalSortExpr::new_default(Arc::new(Column::new("z", 2))),
2694            PhysicalSortExpr::new_default(Arc::new(Column::new("y", 1))),
2695        ]);
2696        let join_type = JoinType::Inner;
2697        let left_columns_len = 5;
2698        let maintains_input_orders = [[true, false], [false, true]];
2699        let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];
2700
2701        let expected = [
2702            LexOrdering::new(vec![
2703                PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
2704                PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
2705                PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
2706                PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))),
2707                PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))),
2708            ]),
2709            LexOrdering::new(vec![
2710                PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))),
2711                PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))),
2712                PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
2713                PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
2714                PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
2715            ]),
2716        ];
2717
2718        for (i, (maintains_input_order, probe_side)) in
2719            maintains_input_orders.iter().zip(probe_sides).enumerate()
2720        {
2721            assert_eq!(
2722                calculate_join_output_ordering(
2723                    left_ordering.as_ref(),
2724                    right_ordering.as_ref(),
2725                    join_type,
2726                    left_columns_len,
2727                    maintains_input_order,
2728                    probe_side,
2729                )?,
2730                expected[i]
2731            );
2732        }
2733
2734        Ok(())
2735    }
2736
2737    fn create_test_batch(num_rows: usize) -> RecordBatch {
2738        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2739        let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
2740        RecordBatch::try_new(schema, vec![data]).unwrap()
2741    }
2742
2743    fn assert_split_batches(
2744        batches: Vec<(RecordBatch, bool)>,
2745        batch_size: usize,
2746        num_rows: usize,
2747    ) {
2748        let mut row_count = 0;
2749        for (batch, last) in batches.into_iter() {
2750            assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size));
2751            let column = batch
2752                .column(0)
2753                .as_any()
2754                .downcast_ref::<Int32Array>()
2755                .unwrap();
2756            for i in 0..batch.num_rows() {
2757                assert_eq!(column.value(i), i as i32 + row_count as i32);
2758            }
2759            row_count += batch.num_rows();
2760            assert_eq!(last, row_count == num_rows);
2761        }
2762    }
2763
2764    #[rstest]
2765    #[test]
2766    fn test_batch_splitter(
2767        #[values(1, 3, 11)] batch_size: usize,
2768        #[values(1, 6, 50)] num_rows: usize,
2769    ) {
2770        let mut splitter = BatchSplitter::new(batch_size);
2771        splitter.set_batch(create_test_batch(num_rows));
2772
2773        let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size));
2774        while let Some(batch) = splitter.next() {
2775            batches.push(batch);
2776        }
2777
2778        assert!(splitter.next().is_none());
2779        assert_split_batches(batches, batch_size, num_rows);
2780    }
2781
2782    #[tokio::test]
2783    async fn test_swap_reverting_projection() {
2784        let left_schema = Schema::new(vec![
2785            Field::new("a", DataType::Int32, false),
2786            Field::new("b", DataType::Int32, false),
2787        ]);
2788
2789        let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);
2790
2791        let proj = swap_reverting_projection(&left_schema, &right_schema);
2792
2793        assert_eq!(proj.len(), 3);
2794
2795        let proj_expr = &proj[0];
2796        assert_eq!(proj_expr.alias, "a");
2797        assert_col_expr(&proj_expr.expr, "a", 1);
2798
2799        let proj_expr = &proj[1];
2800        assert_eq!(proj_expr.alias, "b");
2801        assert_col_expr(&proj_expr.expr, "b", 2);
2802
2803        let proj_expr = &proj[2];
2804        assert_eq!(proj_expr.alias, "c");
2805        assert_col_expr(&proj_expr.expr, "c", 0);
2806    }
2807
2808    fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
2809        let col = expr
2810            .as_any()
2811            .downcast_ref::<Column>()
2812            .expect("Projection items should be Column expression");
2813        assert_eq!(col.name(), name);
2814        assert_eq!(col.index(), index);
2815    }
2816
2817    #[test]
2818    fn test_join_metadata() -> Result<()> {
2819        let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)])
2820            .with_metadata(HashMap::from([("key".to_string(), "left".to_string())]));
2821
2822        let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)])
2823            .with_metadata(HashMap::from([("key".to_string(), "right".to_string())]));
2824
2825        let (join_schema, _) =
2826            build_join_schema(&left_schema, &right_schema, &JoinType::Left);
2827        assert_eq!(
2828            join_schema.metadata(),
2829            &HashMap::from([("key".to_string(), "left".to_string())])
2830        );
2831        let (join_schema, _) =
2832            build_join_schema(&left_schema, &right_schema, &JoinType::Right);
2833        assert_eq!(
2834            join_schema.metadata(),
2835            &HashMap::from([("key".to_string(), "right".to_string())])
2836        );
2837
2838        Ok(())
2839    }
2840}