datafusion_physical_plan/joins/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Join related functionality used both on logical and physical plans

use std::cmp::min;
use std::collections::HashSet;
use std::fmt::{self, Debug};
use std::future::Future;
use std::iter::once;
use std::ops::Range;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::joins::SharedBitmapBuilder;
use crate::metrics::{self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
use crate::projection::{ProjectionExec, ProjectionExpr};
use crate::{
    ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
};
// compatibility
pub use super::join_filter::JoinFilter;
pub use super::join_hash_map::JoinHashMapType;
pub use crate::joins::{JoinOn, JoinOnRef};

use ahash::RandomState;
use arrow::array::{
    builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType,
    BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions,
    UInt32Array, UInt32Builder, UInt64Array,
};
use arrow::array::{ArrayRef, BooleanArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::kernels::cmp::eq;
use arrow::compute::{self, and, take, FilterBuilder};
use arrow::datatypes::{
    ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
};
use arrow_ord::cmp::not_distinct;
use arrow_schema::ArrowError;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::stats::Precision;
use datafusion_common::{
    plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult,
};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{
    add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr,
    PhysicalExprRef,
};

use datafusion_physical_expr_common::datum::compare_op_for_nested;
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
use parking_lot::Mutex;

/// Checks whether the schemas `left` and `right`, together with the `on` columns,
/// represent a valid join: every column referenced by `on` must exist in the
/// corresponding input schema.
pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
    let left: HashSet<Column> = left
        .fields()
        .iter()
        .enumerate()
        .map(|(idx, f)| Column::new(f.name(), idx))
        .collect();
    let right: HashSet<Column> = right
        .fields()
        .iter()
        .enumerate()
        .map(|(idx, f)| Column::new(f.name(), idx))
        .collect();

    check_join_set_is_valid(&left, &right, on)
}

/// Checks whether the column sets `left` and `right`, together with the `on`
/// expressions, compose a valid join: every column referenced by `on` must be
/// contained in the corresponding side's column set.
fn check_join_set_is_valid(
    left: &HashSet<Column>,
    right: &HashSet<Column>,
    on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> Result<()> {
    let on_left = &on
        .iter()
        .flat_map(|on| collect_columns(&on.0))
        .collect::<HashSet<_>>();
    let left_missing = on_left.difference(left).collect::<HashSet<_>>();

    let on_right = &on
        .iter()
        .flat_map(|on| collect_columns(&on.1))
        .collect::<HashSet<_>>();
    let right_missing = on_right.difference(right).collect::<HashSet<_>>();

    if !left_missing.is_empty() | !right_missing.is_empty() {
        return plan_err!(
            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}"
        );
    };

    Ok(())
}

/// Adjusts the right side's output partitioning to the column indices of the
/// joined schema, i.e. shifts every column reference by `left_columns_len`.
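///
/// For illustration (a hypothetical partitioning, not taken from a real plan): if the
/// right child is hash-partitioned and the left side has 3 columns, every right column
/// reference is shifted by 3 in the joined schema:
///
/// ```text
/// left_columns_len = 3
/// Hash([b@0], 8)      -->  Hash([b@3], 8)
/// RoundRobinBatch(8)  -->  RoundRobinBatch(8)   (non-hash partitionings are unchanged)
/// ```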
pub fn adjust_right_output_partitioning(
    right_partitioning: &Partitioning,
    left_columns_len: usize,
) -> Result<Partitioning> {
    let result = match right_partitioning {
        Partitioning::Hash(exprs, size) => {
            let new_exprs = exprs
                .iter()
                .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _))
                .collect::<Result<_>>()?;
            Partitioning::Hash(new_exprs, *size)
        }
        result => result.clone(),
    };
    Ok(result)
}

/// Calculate the output ordering of a given join operation.
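///
/// A sketch of the `Inner` join special case (column names are illustrative only):
///
/// ```text
/// maintains_input_order: [true, false], probe_side: Left, left_columns_len: 2
/// left ordering:   [a@0 ASC]
/// right ordering:  [x@0 ASC]
/// output ordering: [a@0 ASC, x@2 ASC]   (right columns shifted by left_columns_len)
/// ```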
pub fn calculate_join_output_ordering(
    left_ordering: Option<&LexOrdering>,
    right_ordering: Option<&LexOrdering>,
    join_type: JoinType,
    left_columns_len: usize,
    maintains_input_order: &[bool],
    probe_side: Option<JoinSide>,
) -> Result<Option<LexOrdering>> {
    match maintains_input_order {
        [true, false] => {
            // Special case: we can prefix the right side's ordering with the left side's ordering.
            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) {
                if let Some(right_ordering) = right_ordering.cloned() {
                    let right_offset = add_offset_to_physical_sort_exprs(
                        right_ordering,
                        left_columns_len as _,
                    )?;
                    return if let Some(left_ordering) = left_ordering {
                        let mut result = left_ordering.clone();
                        result.extend(right_offset);
                        Ok(Some(result))
                    } else {
                        Ok(LexOrdering::new(right_offset))
                    };
                }
            }
            Ok(left_ordering.cloned())
        }
        [false, true] => {
            // Special case: we can prefix the left side's ordering with the right side's ordering.
            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) {
                return if let Some(right_ordering) = right_ordering.cloned() {
                    let mut right_offset = add_offset_to_physical_sort_exprs(
                        right_ordering,
                        left_columns_len as _,
                    )?;
                    if let Some(left_ordering) = left_ordering {
                        right_offset.extend(left_ordering.clone());
                    }
                    Ok(LexOrdering::new(right_offset))
                } else {
                    Ok(left_ordering.cloned())
                };
            }
            let Some(right_ordering) = right_ordering else {
                return Ok(None);
            };
            match join_type {
                JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
                    add_offset_to_physical_sort_exprs(
                        right_ordering.clone(),
                        left_columns_len as _,
                    )
                    .map(LexOrdering::new)
                }
                _ => Ok(Some(right_ordering.clone())),
            }
        }
        // Doesn't maintain ordering, output ordering is None.
        [false, false] => Ok(None),
        [true, true] => unreachable!("Cannot maintain ordering of both sides"),
        _ => unreachable!("Join operators can not have more than two children"),
    }
}

/// Information about the index and placement (left or right) of the columns
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnIndex {
    /// Index of the column
    pub index: usize,
    /// Whether the column is at the left or right side
    pub side: JoinSide,
}

/// Returns the output field given the input field. Outer joins may
/// insert nulls even if the input field was not nullable.
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
    let force_nullable = match join_type {
        JoinType::Inner => false,
        JoinType::Left => !is_left, // right input is padded with nulls
        JoinType::Right => is_left, // left input is padded with nulls
        JoinType::Full => true,     // both inputs can be padded with nulls
        JoinType::LeftSemi => false, // doesn't introduce nulls
        JoinType::RightSemi => false, // doesn't introduce nulls
        JoinType::LeftAnti => false, // doesn't introduce nulls
        JoinType::RightAnti => false, // doesn't introduce nulls
        JoinType::LeftMark => false,
        JoinType::RightMark => false,
    };

    if force_nullable {
        old_field.clone().with_nullable(true)
    } else {
        old_field.clone()
    }
}

/// Creates a schema for a join operation.
/// The fields from the left side are first.
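///
/// For example (hypothetical schemas), a `Left` join of `left: (a, b)` with
/// `right: (c)` produces:
///
/// ```text
/// output fields:  [a, b, c]            (c becomes nullable, since the right side
///                                        may be padded with nulls)
/// column indices: [(0, Left), (1, Left), (0, Right)]
/// ```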
pub fn build_join_schema(
    left: &Schema,
    right: &Schema,
    join_type: &JoinType,
) -> (Schema, Vec<ColumnIndex>) {
    let left_fields = || {
        left.fields()
            .iter()
            .map(|f| output_join_field(f, join_type, true))
            .enumerate()
            .map(|(index, f)| {
                (
                    f,
                    ColumnIndex {
                        index,
                        side: JoinSide::Left,
                    },
                )
            })
    };

    let right_fields = || {
        right
            .fields()
            .iter()
            .map(|f| output_join_field(f, join_type, false))
            .enumerate()
            .map(|(index, f)| {
                (
                    f,
                    ColumnIndex {
                        index,
                        side: JoinSide::Right,
                    },
                )
            })
    };

    let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
            // left then right
            left_fields().chain(right_fields()).unzip()
        }
        JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
        JoinType::LeftMark => {
            let right_field = once((
                Field::new("mark", arrow::datatypes::DataType::Boolean, false),
                ColumnIndex {
                    index: 0,
                    side: JoinSide::None,
                },
            ));
            left_fields().chain(right_field).unzip()
        }
        JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
        JoinType::RightMark => {
            let left_field = once((
                Field::new("mark", arrow_schema::DataType::Boolean, false),
                ColumnIndex {
                    index: 0,
                    side: JoinSide::None,
                },
            ));
            right_fields().chain(left_field).unzip()
        }
    };

    let (schema1, schema2) = match join_type {
        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right),
        _ => (right, left),
    };

    let metadata = schema1
        .metadata()
        .clone()
        .into_iter()
        .chain(schema2.metadata().clone())
        .collect();

    (fields.finish().with_metadata(metadata), column_indices)
}

/// A [`OnceAsync`] runs an `async` closure once, where multiple calls to
/// [`OnceAsync::try_once`] return a [`OnceFut`] that resolves to the result of the
/// same computation.
///
/// This is useful for joins where the results of one child are needed to proceed
/// with multiple output streams.
///
/// For example, in a hash join, one input is buffered and shared across
/// potentially multiple output partitions. Each output partition must wait for
/// the hash table to be built before proceeding.
///
/// Each output partition waits on the same `OnceAsync` before proceeding.
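///
/// A minimal usage sketch (the field and helper names below are illustrative, not part
/// of this module):
///
/// ```text
/// // shared by all output partitions of the join:
/// left_fut: OnceAsync<JoinLeftData>
///
/// // in each partition's `execute()`:
/// let once_fut = self.left_fut.try_once(|| Ok(collect_left_input(/* ... */)))?;
///
/// // later, in the stream's `poll_next()`:
/// let left_data = ready!(once_fut.get_shared(cx))?;
/// ```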
pub(crate) struct OnceAsync<T> {
    fut: Mutex<Option<SharedResult<OnceFut<T>>>>,
}

impl<T> Default for OnceAsync<T> {
    fn default() -> Self {
        Self {
            fut: Mutex::new(None),
        }
    }
}

impl<T> Debug for OnceAsync<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "OnceAsync")
    }
}

impl<T: 'static> OnceAsync<T> {
    /// If this is the first call to this function on this object, will invoke
    /// `f` to obtain a future and return a [`OnceFut`] referring to this. `f`
    /// may fail, in which case its error is returned.
    ///
    /// If this is not the first call, will return a [`OnceFut`] referring
    /// to the same future as was returned by the first call - or the same
    /// error if the initial call to `f` failed.
    pub(crate) fn try_once<F, Fut>(&self, f: F) -> Result<OnceFut<T>>
    where
        F: FnOnce() -> Result<Fut>,
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        self.fut
            .lock()
            .get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new))
            .clone()
            .map_err(DataFusionError::Shared)
    }
}

/// The shared future type used internally within [`OnceAsync`]
type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;

/// A [`OnceFut`] represents a shared asynchronous computation that is evaluated
/// once for all of its clones, with [`OnceFut::get`] providing a non-consuming interface
/// to drive the underlying [`Future`] to completion.
pub(crate) struct OnceFut<T> {
    state: OnceFutState<T>,
}

impl<T> Clone for OnceFut<T> {
    fn clone(&self) -> Self {
        Self {
            state: self.state.clone(),
        }
    }
}

/// Shared state between statistics aggregators for a join
/// operation.
#[derive(Clone, Debug, Default)]
struct PartialJoinStatistics {
    pub num_rows: usize,
    pub column_statistics: Vec<ColumnStatistics>,
}

/// Estimate the statistics for the given join's output.
pub(crate) fn estimate_join_statistics(
    left_stats: Statistics,
    right_stats: Statistics,
    on: JoinOn,
    join_type: &JoinType,
    schema: &Schema,
) -> Result<Statistics> {
    let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on);
    let (num_rows, column_statistics) = match join_stats {
        Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics),
        None => (Precision::Absent, Statistics::unknown_column(schema)),
    };
    Ok(Statistics {
        num_rows,
        total_byte_size: Precision::Absent,
        column_statistics,
    })
}

// Estimate the cardinality for the given join with input statistics.
fn estimate_join_cardinality(
    join_type: &JoinType,
    left_stats: Statistics,
    right_stats: Statistics,
    on: &JoinOn,
) -> Option<PartialJoinStatistics> {
    let (left_col_stats, right_col_stats) = on
        .iter()
        .map(|(left, right)| {
            match (
                left.as_any().downcast_ref::<Column>(),
                right.as_any().downcast_ref::<Column>(),
            ) {
                (Some(left), Some(right)) => (
                    left_stats.column_statistics[left.index()].clone(),
                    right_stats.column_statistics[right.index()].clone(),
                ),
                _ => (
                    ColumnStatistics::new_unknown(),
                    ColumnStatistics::new_unknown(),
                ),
            }
        })
        .unzip::<_, _, Vec<_>, Vec<_>>();

    match join_type {
        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
            let ij_cardinality = estimate_inner_join_cardinality(
                Statistics {
                    num_rows: left_stats.num_rows,
                    total_byte_size: Precision::Absent,
                    column_statistics: left_col_stats,
                },
                Statistics {
                    num_rows: right_stats.num_rows,
                    total_byte_size: Precision::Absent,
                    column_statistics: right_col_stats,
                },
            )?;

            // The inner join cardinality can also be used to estimate
            // the cardinality of left/right/full outer joins, as long as it
            // is greater than the minimum cardinality constraints of these
            // joins (so that we don't underestimate the cardinality).
            let cardinality = match join_type {
                JoinType::Inner => ij_cardinality,
                JoinType::Left => ij_cardinality.max(&left_stats.num_rows),
                JoinType::Right => ij_cardinality.max(&right_stats.num_rows),
                JoinType::Full => ij_cardinality
                    .max(&left_stats.num_rows)
                    .add(&ij_cardinality.max(&right_stats.num_rows))
                    .sub(&ij_cardinality),
                _ => unreachable!(),
            };

            Some(PartialJoinStatistics {
                num_rows: *cardinality.get_value()?,
                // We don't do anything specific here, just combine the existing
                // statistics which might yield subpar results (although it is
                // true, esp regarding min/max). For a better estimation, we need
                // filter selectivity analysis first.
                column_statistics: left_stats
                    .column_statistics
                    .into_iter()
                    .chain(right_stats.column_statistics)
                    .collect(),
            })
        }

        // For semi joins, the estimate is either zero (when the inputs are
        // non-overlapping according to statistics) or equal to the number of
        // rows of the outer input.
        JoinType::LeftSemi | JoinType::RightSemi => {
            let (outer_stats, inner_stats) = match join_type {
                JoinType::LeftSemi => (left_stats, right_stats),
                _ => (right_stats, left_stats),
            };
            let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
                Some(estimation) => *estimation.get_value()?,
                None => *outer_stats.num_rows.get_value()?,
            };

            Some(PartialJoinStatistics {
                num_rows: cardinality,
                column_statistics: outer_stats.column_statistics,
            })
        }

        // For anti joins, the estimate always equals the outer input's statistics,
        // since non-overlapping inputs do not affect the estimate.
        JoinType::LeftAnti | JoinType::RightAnti => {
            let outer_stats = match join_type {
                JoinType::LeftAnti => left_stats,
                _ => right_stats,
            };

            Some(PartialJoinStatistics {
                num_rows: *outer_stats.num_rows.get_value()?,
                column_statistics: outer_stats.column_statistics,
            })
        }

        JoinType::LeftMark => {
            let num_rows = *left_stats.num_rows.get_value()?;
            let mut column_statistics = left_stats.column_statistics;
            column_statistics.push(ColumnStatistics::new_unknown());
            Some(PartialJoinStatistics {
                num_rows,
                column_statistics,
            })
        }
        JoinType::RightMark => {
            let num_rows = *right_stats.num_rows.get_value()?;
            let mut column_statistics = right_stats.column_statistics;
            column_statistics.push(ColumnStatistics::new_unknown());
            Some(PartialJoinStatistics {
                num_rows,
                column_statistics,
            })
        }
    }
}

/// Estimate the inner join cardinality by using the basic building blocks of
/// column-level statistics and the total row count. This is a very naive and
/// very conservative implementation that quickly gives up if there are not
/// enough input statistics.
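///
/// A worked example (hypothetical statistics), assuming a single join column whose
/// min/max ranges overlap:
///
/// ```text
/// left:  1000 rows, at most 50 distinct join keys
/// right:  200 rows, at most 40 distinct join keys
///
/// estimate = (1000 * 200) / max(50, 40) = 4000 rows
/// ```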
fn estimate_inner_join_cardinality(
    left_stats: Statistics,
    right_stats: Statistics,
) -> Option<Precision<usize>> {
    // Return immediately if the inputs are considered non-overlapping
    if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) {
        return Some(estimation);
    };

    // The algorithm here is partly based on the non-histogram selectivity estimation
    // from Spark's Catalyst optimizer.
    let mut join_selectivity = Precision::Absent;
    for (left_stat, right_stat) in left_stats
        .column_statistics
        .iter()
        .zip(right_stats.column_statistics.iter())
    {
        // Return if any of the statistics bounds are undefined
        if left_stat.min_value.get_value().is_none()
            || left_stat.max_value.get_value().is_none()
            || right_stat.min_value.get_value().is_none()
            || right_stat.max_value.get_value().is_none()
        {
            return None;
        }

        let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat);
        let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat);
        let max_distinct = left_max_distinct.max(&right_max_distinct);
        if max_distinct.get_value().is_some() {
            // Seems like there are a few implementations of this algorithm that implement
            // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
            // further exploration.
            join_selectivity = max_distinct;
        }
    }

    // With the assumption that the smaller input's domain is generally represented in the bigger
    // input's domain, we can estimate the inner join's cardinality by taking the cartesian product
    // of the two inputs and normalizing it by the selectivity factor.
    let left_num_rows = left_stats.num_rows.get_value()?;
    let right_num_rows = right_stats.num_rows.get_value()?;
    match join_selectivity {
        Precision::Exact(value) if value > 0 => {
            Some(Precision::Exact((left_num_rows * right_num_rows) / value))
        }
        Precision::Inexact(value) if value > 0 => {
            Some(Precision::Inexact((left_num_rows * right_num_rows) / value))
        }
        // Since we don't have any information about the selectivity (which is derived
        // from the number of distinct rows information) we can give up here for now.
        // And let other passes handle this (otherwise we would need to produce an
        // overestimation using just the cartesian product).
        _ => None,
    }
}

/// Estimates whether the inputs are non-overlapping, using input statistics.
/// If the inputs are disjoint, returns a zero estimate; otherwise returns `None`.
fn estimate_disjoint_inputs(
    left_stats: &Statistics,
    right_stats: &Statistics,
) -> Option<Precision<usize>> {
    for (left_stat, right_stat) in left_stats
        .column_statistics
        .iter()
        .zip(right_stats.column_statistics.iter())
    {
        // If there is no overlap in any of the join columns, this means the join
        // itself is disjoint and the cardinality is 0. Though we can only assume
        // this when the statistics are exact (since it is a very strong assumption).
        let left_min_val = left_stat.min_value.get_value();
        let right_max_val = right_stat.max_value.get_value();
        if left_min_val.is_some()
            && right_max_val.is_some()
            && left_min_val > right_max_val
        {
            return Some(
                if left_stat.min_value.is_exact().unwrap_or(false)
                    && right_stat.max_value.is_exact().unwrap_or(false)
                {
                    Precision::Exact(0)
                } else {
                    Precision::Inexact(0)
                },
            );
        }

        let left_max_val = left_stat.max_value.get_value();
        let right_min_val = right_stat.min_value.get_value();
        if left_max_val.is_some()
            && right_min_val.is_some()
            && left_max_val < right_min_val
        {
            return Some(
                if left_stat.max_value.is_exact().unwrap_or(false)
                    && right_stat.min_value.is_exact().unwrap_or(false)
                {
                    Precision::Exact(0)
                } else {
                    Precision::Inexact(0)
                },
            );
        }
    }

    None
}

/// Estimate the maximum number of distinct values that can be present in the
/// given column, from its statistics. If distinct_count is available, uses it
/// directly. Otherwise, if the column is numeric and has min/max values, it
/// estimates the maximum distinct count from those.
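///
/// For illustration (hypothetical statistics for an `Int32` column):
///
/// ```text
/// num_rows: Exact(100), null_count: Exact(10)  ->  at most 90 distinct values
/// min: Exact(0), max: Exact(9)                 ->  range holds only 10 values,
///                                                  so the result is Exact(10)
/// ```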
fn max_distinct_count(
    num_rows: &Precision<usize>,
    stats: &ColumnStatistics,
) -> Precision<usize> {
    match &stats.distinct_count {
        &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc,
        _ => {
            // The number can never be greater than the number of rows we have
            // minus the nulls (since they don't count as distinct values).
            let result = match num_rows {
                Precision::Absent => Precision::Absent,
                Precision::Inexact(count) => {
                    // To safeguard against inexact number of rows (e.g. 0) being smaller than
                    // an exact null count we need to do a checked subtraction.
                    match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
                        None => Precision::Inexact(0),
                        Some(non_null_count) => Precision::Inexact(non_null_count),
                    }
                }
                Precision::Exact(count) => {
                    let count = count - stats.null_count.get_value().unwrap_or(&0);
                    if stats.null_count.is_exact().unwrap_or(false) {
                        Precision::Exact(count)
                    } else {
                        Precision::Inexact(count)
                    }
                }
            };
            // Cap the estimate using the number of possible values:
            if let (Some(min), Some(max)) =
                (stats.min_value.get_value(), stats.max_value.get_value())
            {
                if let Some(range_dc) = Interval::try_new(min.clone(), max.clone())
                    .ok()
                    .and_then(|e| e.cardinality())
                {
                    let range_dc = range_dc as usize;
                    // Note that the `unwrap` calls in the below statement are safe.
                    return if matches!(result, Precision::Absent)
                        || &range_dc < result.get_value().unwrap()
                    {
                        if stats.min_value.is_exact().unwrap()
                            && stats.max_value.is_exact().unwrap()
                        {
                            Precision::Exact(range_dc)
                        } else {
                            Precision::Inexact(range_dc)
                        }
                    } else {
                        result
                    };
                }
            }

            result
        }
    }
}

enum OnceFutState<T> {
    Pending(OnceFutPending<T>),
    Ready(SharedResult<Arc<T>>),
}

impl<T> Clone for OnceFutState<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Pending(p) => Self::Pending(p.clone()),
            Self::Ready(r) => Self::Ready(r.clone()),
        }
    }
}

impl<T: 'static> OnceFut<T> {
    /// Create a new [`OnceFut`] from a [`Future`]
    pub(crate) fn new<Fut>(fut: Fut) -> Self
    where
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        Self {
            state: OnceFutState::Pending(
                fut.map(|res| res.map(Arc::new).map_err(Arc::new))
                    .boxed()
                    .shared(),
            ),
        }
    }

    /// Get the result of the computation if it is ready, without consuming it
    pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
        if let OnceFutState::Pending(fut) = &mut self.state {
            let r = ready!(fut.poll_unpin(cx));
            self.state = OnceFutState::Ready(r);
        }

        // Cannot use loop as this would trip up the borrow checker
        match &self.state {
            OnceFutState::Pending(_) => unreachable!(),
            OnceFutState::Ready(r) => Poll::Ready(
                r.as_ref()
                    .map(|r| r.as_ref())
                    .map_err(DataFusionError::from),
            ),
        }
    }

    /// Get shared reference to the result of the computation if it is ready, without consuming it
    pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
        if let OnceFutState::Pending(fut) = &mut self.state {
            let r = ready!(fut.poll_unpin(cx));
            self.state = OnceFutState::Ready(r);
        }

        match &self.state {
            OnceFutState::Pending(_) => unreachable!(),
            OnceFutState::Ready(r) => {
                Poll::Ready(r.clone().map_err(DataFusionError::Shared))
            }
        }
    }
}

/// Returns whether we should use a bitmap to track the 'joined' status of each
/// row in each incoming right batch.
///
/// For example, in right joins we have to use a bitmap to track matched
/// right-side rows, and later enter an `EmitRightUnmatched` stage to emit the
/// unmatched right rows.
pub(crate) fn need_produce_right_in_final(join_type: JoinType) -> bool {
    matches!(
        join_type,
        JoinType::Full
            | JoinType::Right
            | JoinType::RightAnti
            | JoinType::RightMark
            | JoinType::RightSemi
    )
}

/// Some join types need to maintain a bitmap of matched indices for the left side,
/// and use that bitmap to generate part of the join result.
///
/// For example, a `Left` join can emit matched rows while iterating over the right
/// side, but it must keep the matched-indices bitmap to emit the unmatched left rows at the end.
pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
    matches!(
        join_type,
        JoinType::Left
            | JoinType::LeftAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::Full
    )
}

pub(crate) fn get_final_indices_from_shared_bitmap(
    shared_bitmap: &SharedBitmapBuilder,
    join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
    let bitmap = shared_bitmap.lock();
    get_final_indices_from_bit_map(&bitmap, join_type)
}

/// At the end of join execution, the bitmap of matched indices is used to
/// generate the final left and right indices.
///
/// For example:
///
/// 1. left_bit_map: `[true, false, true, true, false]`
/// 2. join_type: `Left`
///
/// The result is: `([1,4], [null, null])`
pub(crate) fn get_final_indices_from_bit_map(
    left_bit_map: &BooleanBufferBuilder,
    join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
    let left_size = left_bit_map.len();
    if join_type == JoinType::LeftMark {
        let left_indices = (0..left_size as u64).collect::<UInt64Array>();
        let right_indices = (0..left_size)
            .map(|idx| left_bit_map.get_bit(idx).then_some(0))
            .collect::<UInt32Array>();
        return (left_indices, right_indices);
    }
    let left_indices = if join_type == JoinType::LeftSemi {
        (0..left_size)
            .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
            .collect::<UInt64Array>()
    } else {
        // For `Left`, `LeftAnti` and `Full` joins, produce the unmatched left rows
        (0..left_size)
            .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
            .collect::<UInt64Array>()
    };
    // right_indices: all elements on the right side are null
    let mut builder = UInt32Builder::with_capacity(left_indices.len());
    builder.append_nulls(left_indices.len());
    let right_indices = builder.finish();
    (left_indices, right_indices)
}

pub(crate) fn apply_join_filter_to_indices(
    build_input_buffer: &RecordBatch,
    probe_batch: &RecordBatch,
    build_indices: UInt64Array,
    probe_indices: UInt32Array,
    filter: &JoinFilter,
    build_side: JoinSide,
    max_intermediate_size: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
    if build_indices.is_empty() && probe_indices.is_empty() {
        return Ok((build_indices, probe_indices));
    };

    let filter_result = if let Some(max_size) = max_intermediate_size {
        let mut filter_results =
            Vec::with_capacity(build_indices.len().div_ceil(max_size));

        for i in (0..build_indices.len()).step_by(max_size) {
            let end = min(build_indices.len(), i + max_size);
            let len = end - i;
            let intermediate_batch = build_batch_from_indices(
                filter.schema(),
                build_input_buffer,
                probe_batch,
                &build_indices.slice(i, len),
                &probe_indices.slice(i, len),
                filter.column_indices(),
                build_side,
            )?;
            let filter_result = filter
                .expression()
                .evaluate(&intermediate_batch)?
                .into_array(intermediate_batch.num_rows())?;
            filter_results.push(filter_result);
        }

        let filter_refs: Vec<&dyn Array> =
            filter_results.iter().map(|a| a.as_ref()).collect();

        compute::concat(&filter_refs)?
    } else {
        let intermediate_batch = build_batch_from_indices(
            filter.schema(),
            build_input_buffer,
            probe_batch,
            &build_indices,
            &probe_indices,
            filter.column_indices(),
            build_side,
        )?;

        filter
            .expression()
            .evaluate(&intermediate_batch)?
            .into_array(intermediate_batch.num_rows())?
    };

    let mask = as_boolean_array(&filter_result)?;

    let left_filtered = compute::filter(&build_indices, mask)?;
    let right_filtered = compute::filter(&probe_indices, mask)?;
    Ok((
        downcast_array(left_filtered.as_ref()),
        downcast_array(right_filtered.as_ref()),
    ))
}

/// Returns a new [RecordBatch] by combining the build-side and probe-side columns
/// according to the given indices. The resulting batch has [Schema] `schema`.
pub(crate) fn build_batch_from_indices(
    schema: &Schema,
    build_input_buffer: &RecordBatch,
    probe_batch: &RecordBatch,
    build_indices: &UInt64Array,
    probe_indices: &UInt32Array,
    column_indices: &[ColumnIndex],
    build_side: JoinSide,
) -> Result<RecordBatch> {
    if schema.fields().is_empty() {
        let options = RecordBatchOptions::new()
            .with_match_field_names(true)
            .with_row_count(Some(build_indices.len()));

        return Ok(RecordBatch::try_new_with_options(
            Arc::new(schema.clone()),
            vec![],
            &options,
        )?);
    }

    // build the columns of the new [RecordBatch]:
    // 1. pick whether the column is from the left or right
    // 2. based on the pick, `take` items from the different RecordBatches
    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());

    for column_index in column_indices {
        let array = if column_index.side == JoinSide::None {
            // For mark joins, the mark column is true if the index is not null, otherwise false
            Arc::new(compute::is_not_null(probe_indices)?)
        } else if column_index.side == build_side {
            let array = build_input_buffer.column(column_index.index);
            if array.is_empty() || build_indices.null_count() == build_indices.len() {
                // An outer join generates a null index when there is no match on our side.
                // Therefore, the input may be empty while we still need to produce an n-length
                // null array, where n is the length of the index array.
                assert_eq!(build_indices.null_count(), build_indices.len());
                new_null_array(array.data_type(), build_indices.len())
            } else {
                take(array.as_ref(), build_indices, None)?
            }
        } else {
            let array = probe_batch.column(column_index.index);
            if array.is_empty() || probe_indices.null_count() == probe_indices.len() {
                assert_eq!(probe_indices.null_count(), probe_indices.len());
                new_null_array(array.data_type(), probe_indices.len())
            } else {
                take(array.as_ref(), probe_indices, None)?
            }
        };

        columns.push(array);
    }
    Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
}

/// Returns a new [RecordBatch] resulting from a join where the build/left side is empty.
/// The resulting batch has [Schema] `schema`.
pub(crate) fn build_batch_empty_build_side(
    schema: &Schema,
    build_batch: &RecordBatch,
    probe_batch: &RecordBatch,
    column_indices: &[ColumnIndex],
    join_type: JoinType,
) -> Result<RecordBatch> {
    match join_type {
        // these join types only return data if the left side is not empty, so we return an
        // empty RecordBatch
        JoinType::Inner
        | JoinType::Left
        | JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))),

        // the remaining joins will return data for the right columns and null for the left ones
        JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => {
            let num_rows = probe_batch.num_rows();
            let mut columns: Vec<Arc<dyn Array>> =
                Vec::with_capacity(schema.fields().len());

            for column_index in column_indices {
                let array = match column_index.side {
                    // left -> null array
                    JoinSide::Left => new_null_array(
                        build_batch.column(column_index.index).data_type(),
                        num_rows,
                    ),
                    // right -> respective right array
                    JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)),
                    // right mark -> unset boolean array as there are no matches on the left side
                    JoinSide::None => Arc::new(BooleanArray::new(
                        BooleanBuffer::new_unset(num_rows),
                        None,
                    )),
                };

                columns.push(array);
            }

            Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
        }
    }
}

/// Takes the matched indices for the left and right sides and adjusts them
/// according to the join type.
pub(crate) fn adjust_indices_by_join_type(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    adjust_range: Range<usize>,
    join_type: JoinType,
    preserve_order_for_right: bool,
) -> Result<(UInt64Array, UInt32Array)> {
    match join_type {
        JoinType::Inner => {
            // matched
            Ok((left_indices, right_indices))
        }
        JoinType::Left => {
            // matched
            Ok((left_indices, right_indices))
            // unmatched left rows are produced at the end of the loop; they are tracked in the left visited bitmap
        }
        JoinType::Right => {
            // combine the matched and unmatched right result together
            append_right_indices(
                left_indices,
                right_indices,
                adjust_range,
                preserve_order_for_right,
            )
        }
        JoinType::Full => {
            append_right_indices(left_indices, right_indices, adjust_range, false)
        }
        JoinType::RightSemi => {
            // remove duplicated records on the right side
            let right_indices = get_semi_indices(adjust_range, &right_indices);
            // the left_indices will not be used later for the `right semi` join
            Ok((left_indices, right_indices))
        }
        JoinType::RightAnti => {
            // remove duplicated records on the right side and get the anti indices for the right side
            let right_indices = get_anti_indices(adjust_range, &right_indices);
            // the left_indices will not be used later for the `right anti` join
            Ok((left_indices, right_indices))
        }
        JoinType::RightMark => {
            let right_indices = get_mark_indices(&adjust_range, &right_indices);
            let left_indices_vec: Vec<u64> = adjust_range.map(|i| i as u64).collect();
            let left_indices = UInt64Array::from(left_indices_vec);
            Ok((left_indices, right_indices))
        }
        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
            // Matched or unmatched left rows are produced at the end of the loop,
            // driven by the left visited bitmap, so no indices are emitted here.
            Ok((
                UInt64Array::from_iter_values(vec![]),
                UInt32Array::from_iter_values(vec![]),
            ))
        }
    }
}

/// Appends right indices to left indices based on the specified order mode.
///
/// The function operates in two modes:
/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices
///    are inserted in order using the `append_probe_indices_in_order()` method.
/// 2. Otherwise, unmatched probe indices are simply appended after matched ones.
///
/// # Parameters
/// - `left_indices`: UInt64Array of left indices.
/// - `right_indices`: UInt32Array of right indices.
/// - `adjust_range`: Range to adjust the right indices.
/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation.
///
/// # Returns
/// A tuple of updated `UInt64Array` and `UInt32Array`.
pub(crate) fn append_right_indices(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    adjust_range: Range<usize>,
    preserve_order_for_right: bool,
) -> Result<(UInt64Array, UInt32Array)> {
    if preserve_order_for_right {
        Ok(append_probe_indices_in_order(
            left_indices,
            right_indices,
            adjust_range,
        ))
    } else {
        let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices);

        if right_unmatched_indices.is_empty() {
            Ok((left_indices, right_indices))
        } else {
            // `into_builder()` can fail here when there is nothing to be filtered and
            // left_indices or right_indices has the same reference to the cached indices.
            // In that case, we use a slower alternative.

            // the new left indices: left_indices + null array
            let mut new_left_indices_builder =
                left_indices.into_builder().unwrap_or_else(|left_indices| {
                    let mut builder = UInt64Builder::with_capacity(
                        left_indices.len() + right_unmatched_indices.len(),
                    );
                    debug_assert_eq!(
                        left_indices.null_count(),
                        0,
                        "expected left indices to have no nulls"
                    );
                    builder.append_slice(left_indices.values());
                    builder
                });
            new_left_indices_builder.append_nulls(right_unmatched_indices.len());
            let new_left_indices = UInt64Array::from(new_left_indices_builder.finish());

            // the new right indices: right_indices + right_unmatched_indices
            let mut new_right_indices_builder = right_indices
                .into_builder()
                .unwrap_or_else(|right_indices| {
                    let mut builder = UInt32Builder::with_capacity(
                        right_indices.len() + right_unmatched_indices.len(),
                    );
                    debug_assert_eq!(
                        right_indices.null_count(),
                        0,
                        "expected right indices to have no nulls"
                    );
                    builder.append_slice(right_indices.values());
                    builder
                });
            debug_assert_eq!(
                right_unmatched_indices.null_count(),
                0,
                "expected right unmatched indices to have no nulls"
            );
            new_right_indices_builder.append_slice(right_unmatched_indices.values());
            let new_right_indices = UInt32Array::from(new_right_indices_builder.finish());

            Ok((new_left_indices, new_right_indices))
        }
    }
}

/// Returns `range` indices which are not present in `input_indices`
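///
/// For example:
///
/// ```text
/// range: 0..5, input_indices: [1, 3, 3]  ->  [0, 2, 4]
/// ```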
pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
    range: Range<usize>,
    input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    let bitmap = build_range_bitmap(&range, input_indices);
    let offset = range.start;

    // get the anti index
    (range)
        .filter_map(|idx| {
            (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
        })
        .collect()
}

/// Returns intersection of `range` and `input_indices` omitting duplicates
pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
    range: Range<usize>,
    input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    let bitmap = build_range_bitmap(&range, input_indices);
    let offset = range.start;
    // get the semi index
    (range)
        .filter_map(|idx| {
            (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
        })
        .collect()
}

pub(crate) fn get_mark_indices<T: ArrowPrimitiveType>(
    range: &Range<usize>,
    input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<UInt32Type>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    let mut bitmap = build_range_bitmap(range, input_indices);
    PrimitiveArray::new(
        vec![0; range.len()].into(),
        Some(NullBuffer::new(bitmap.finish())),
    )
}

fn build_range_bitmap<T: ArrowPrimitiveType>(
    range: &Range<usize>,
    input: &PrimitiveArray<T>,
) -> BooleanBufferBuilder {
    let mut builder = BooleanBufferBuilder::new(range.len());
    builder.append_n(range.len(), false);

    input.iter().flatten().for_each(|v| {
        let idx = v.as_usize();
        if range.contains(&idx) {
            builder.set_bit(idx - range.start, true);
        }
    });

    builder
}

/// Appends probe indices in order by considering the given build indices.
///
/// This function constructs new build and probe indices by iterating through
/// the provided indices, and appends any missing values between previous and
/// current probe index with a corresponding null build index.
///
/// # Parameters
///
/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices.
/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices.
/// - `range`: The range of indices to consider.
///
/// # Returns
///
/// A tuple of two arrays:
/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices.
/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices.
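///
/// # Example
///
/// ```text
/// build_indices: [10, 12], probe_indices: [1, 3], range: 0..5
///
/// new build indices: [null, 10, null, 12, null]
/// new probe indices: [0,    1,  2,    3,  4]
/// ```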
1267fn append_probe_indices_in_order(
1268    build_indices: PrimitiveArray<UInt64Type>,
1269    probe_indices: PrimitiveArray<UInt32Type>,
1270    range: Range<usize>,
1271) -> (PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>) {
1272    // Builders for new indices:
1273    let mut new_build_indices = UInt64Builder::new();
1274    let mut new_probe_indices = UInt32Builder::new();
1275    // Set previous index as the start index for the initial loop:
1276    let mut prev_index = range.start as u32;
1277    // Zip the two iterators.
1278    debug_assert!(build_indices.len() == probe_indices.len());
1279    for (build_index, probe_index) in build_indices
1280        .values()
1281        .into_iter()
1282        .zip(probe_indices.values().into_iter())
1283    {
1284        // Append values between previous and current probe index with null build index:
1285        for value in prev_index..*probe_index {
1286            new_probe_indices.append_value(value);
1287            new_build_indices.append_null();
1288        }
1289        // Append current indices:
1290        new_probe_indices.append_value(*probe_index);
1291        new_build_indices.append_value(*build_index);
1292        // Set current probe index as previous for the next iteration:
1293        prev_index = probe_index + 1;
1294    }
1295    // Append remaining probe indices after the last valid probe index with null build index.
1296    for value in prev_index..range.end as u32 {
1297        new_probe_indices.append_value(value);
1298        new_build_indices.append_null();
1299    }
1300    // Build arrays and return:
1301    (new_build_indices.finish(), new_probe_indices.finish())
1302}
1303
1304/// Metrics for build & probe joins
1305#[derive(Clone, Debug)]
1306pub(crate) struct BuildProbeJoinMetrics {
1307    pub(crate) baseline: BaselineMetrics,
1308    /// Total time for collecting build-side of join
1309    pub(crate) build_time: metrics::Time,
1310    /// Number of batches consumed by build-side
1311    pub(crate) build_input_batches: metrics::Count,
1312    /// Number of rows consumed by build-side
1313    pub(crate) build_input_rows: metrics::Count,
1314    /// Memory used by build-side in bytes
1315    pub(crate) build_mem_used: metrics::Gauge,
1316    /// Total time for joining probe-side batches to the build-side batches
1317    pub(crate) join_time: metrics::Time,
1318    /// Number of batches consumed by probe-side of this operator
1319    pub(crate) input_batches: metrics::Count,
1320    /// Number of rows consumed by probe-side this operator
1321    pub(crate) input_rows: metrics::Count,
1322    /// Number of batches produced by this operator
1323    pub(crate) output_batches: metrics::Count,
1324}
1325
1326// This Drop implementation updates the elapsed compute part of the metrics.
1327//
1328// Why is this in a Drop?
1329// - We keep track of build_time and join_time separately, but baseline metrics have
1330// a total elapsed_compute time. Instead of remembering to update both the metrics
1331// at the same time, we chose to update elapsed_compute once at the end - summing up
1332// both the parts.
1333//
1334// How does this work?
1335// - The elapsed_compute `Time` is represented by an `Arc<AtomicUsize>`. So even when
1336// this `BuildProbeJoinMetrics` is dropped, the elapsed_compute is usable through the
1337// Arc reference.
1338impl Drop for BuildProbeJoinMetrics {
1339    fn drop(&mut self) {
1340        self.baseline.elapsed_compute().add(&self.build_time);
1341        self.baseline.elapsed_compute().add(&self.join_time);
1342    }
1343}
1344
1345impl BuildProbeJoinMetrics {
1346    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
1347        let baseline = BaselineMetrics::new(metrics, partition);
1348
1349        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
1350
1351        let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition);
1352
1353        let build_input_batches =
1354            MetricBuilder::new(metrics).counter("build_input_batches", partition);
1355
1356        let build_input_rows =
1357            MetricBuilder::new(metrics).counter("build_input_rows", partition);
1358
1359        let build_mem_used =
1360            MetricBuilder::new(metrics).gauge("build_mem_used", partition);
1361
1362        let input_batches =
1363            MetricBuilder::new(metrics).counter("input_batches", partition);
1364
1365        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
1366
1367        let output_batches =
1368            MetricBuilder::new(metrics).counter("output_batches", partition);
1369
1370        Self {
1371            build_time,
1372            build_input_batches,
1373            build_input_rows,
1374            build_mem_used,
1375            join_time,
1376            input_batches,
1377            input_rows,
1378            output_batches,
1379            baseline,
1380        }
1381    }
1382}
1383
1384/// The `handle_state` macro is designed to process the result of a state-changing
1385/// operation. It operates on a `StatefulStreamResult` by matching its variants and
1386/// executing corresponding actions. This macro is used to streamline code that deals
1387/// with state transitions, reducing boilerplate and improving readability.
1388///
1389/// # Cases
1390///
1391/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the
1392///   stream join operation should proceed to the next step.
1393/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the
1394///   result, either yielding a value or indicating the stream is awaiting more
1395///   data.
1396/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
1397///   during the stream join operation.
1398///
1399/// # Arguments
1400///
1401/// * `$match_case`: An expression that evaluates to a `Result<StatefulStreamResult<_>>`.
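///
/// # Example (sketch)
///
/// A hedged sketch of how a stream's poll loop might use this macro; the state
/// enum and `self.process_probe_batch()` are hypothetical names:
///
/// ```ignore
/// loop {
///     return match self.state {
///         // `Continue` results keep looping; `Ready` results are returned.
///         State::ProcessProbeBatch => handle_state!(self.process_probe_batch()),
///         State::Exhausted => Poll::Ready(None),
///     };
/// }
/// ```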
1402#[macro_export]
1403macro_rules! handle_state {
1404    ($match_case:expr) => {
1405        match $match_case {
1406            Ok(StatefulStreamResult::Continue) => continue,
1407            Ok(StatefulStreamResult::Ready(result)) => {
1408                Poll::Ready(Ok(result).transpose())
1409            }
1410            Err(e) => Poll::Ready(Some(Err(e))),
1411        }
1412    };
1413}
1414
1415/// Represents the result of a stateful operation.
1416///
1417/// This enumeration indicates whether the state produced a result that is
1418/// ready for use (`Ready`) or if the operation requires continuation (`Continue`).
1419///
1420/// Variants:
1421/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`.
1422/// - `Continue`: Indicates that the operation is not yet complete and requires further
1423///   processing or more data. When this variant is returned, it typically means that the
1424///   current invocation of the state did not produce a final result, and the operation
1425///   should be invoked again later with more data and possibly with a different state.
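///
/// # Example (sketch)
///
/// A hypothetical state handler signalling either that more input is needed or
/// that an output batch is ready (`needs_more_input` and `build_output_batch`
/// are illustrative names):
///
/// ```ignore
/// fn process(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
///     if self.needs_more_input() {
///         return Ok(StatefulStreamResult::Continue);
///     }
///     Ok(StatefulStreamResult::Ready(Some(self.build_output_batch()?)))
/// }
/// ```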
1426pub enum StatefulStreamResult<T> {
1427    Ready(T),
1428    Continue,
1429}
1430
1431pub(crate) fn symmetric_join_output_partitioning(
1432    left: &Arc<dyn ExecutionPlan>,
1433    right: &Arc<dyn ExecutionPlan>,
1434    join_type: &JoinType,
1435) -> Result<Partitioning> {
1436    let left_columns_len = left.schema().fields.len();
1437    let left_partitioning = left.output_partitioning();
1438    let right_partitioning = right.output_partitioning();
1439    let result = match join_type {
1440        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
1441            left_partitioning.clone()
1442        }
1443        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1444            right_partitioning.clone()
1445        }
1446        JoinType::Inner | JoinType::Right => {
1447            adjust_right_output_partitioning(right_partitioning, left_columns_len)?
1448        }
1449        JoinType::Full => {
1450            // We could also use left partition count as they are necessarily equal.
1451            Partitioning::UnknownPartitioning(right_partitioning.partition_count())
1452        }
1453    };
1454    Ok(result)
1455}
1456
1457pub(crate) fn asymmetric_join_output_partitioning(
1458    left: &Arc<dyn ExecutionPlan>,
1459    right: &Arc<dyn ExecutionPlan>,
1460    join_type: &JoinType,
1461) -> Result<Partitioning> {
1462    let result = match join_type {
1463        JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
1464            right.output_partitioning(),
1465            left.schema().fields().len(),
1466        )?,
1467        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
1468            right.output_partitioning().clone()
1469        }
1470        JoinType::Left
1471        | JoinType::LeftSemi
1472        | JoinType::LeftAnti
1473        | JoinType::Full
1474        | JoinType::LeftMark => Partitioning::UnknownPartitioning(
1475            right.output_partitioning().partition_count(),
1476        ),
1477    };
1478    Ok(result)
1479}
1480
1481/// Trait for incrementally generating Join output.
1482///
1483/// This trait is used to limit the output of some joins,
1484/// so that they do not produce single, overly large batches.
1485pub(crate) trait BatchTransformer: Debug + Clone {
1486    /// Sets the next `RecordBatch` to be processed.
1487    fn set_batch(&mut self, batch: RecordBatch);
1488
1489    /// Retrieves the next `RecordBatch` from the transformer.
1490    /// Returns `None` if all batches have been produced.
1491    /// The boolean flag indicates whether the batch is the last one.
1492    fn next(&mut self) -> Option<(RecordBatch, bool)>;
1493}
1494
1495#[derive(Debug, Clone)]
1496/// A batch transformer that passes each batch through unchanged.
1497pub(crate) struct NoopBatchTransformer {
1498    /// RecordBatch to be processed
1499    batch: Option<RecordBatch>,
1500}
1501
1502impl NoopBatchTransformer {
1503    pub fn new() -> Self {
1504        Self { batch: None }
1505    }
1506}
1507
1508impl BatchTransformer for NoopBatchTransformer {
1509    fn set_batch(&mut self, batch: RecordBatch) {
1510        self.batch = Some(batch);
1511    }
1512
1513    fn next(&mut self) -> Option<(RecordBatch, bool)> {
1514        self.batch.take().map(|batch| (batch, true))
1515    }
1516}
1517
1518#[derive(Debug, Clone)]
1519/// Splits large batches into smaller batches with a maximum number of rows.
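///
/// # Example (sketch)
///
/// Illustrative use only; `batch` is assumed to be an existing `RecordBatch`:
///
/// ```ignore
/// let mut splitter = BatchSplitter::new(1024);
/// splitter.set_batch(batch);
/// while let Some((chunk, last)) = splitter.next() {
///     // Each `chunk` has at most 1024 rows; `last` is true for the final chunk.
/// }
/// ```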
1520pub(crate) struct BatchSplitter {
1521    /// RecordBatch to be split
1522    batch: Option<RecordBatch>,
1523    /// Maximum number of rows in a split batch
1524    batch_size: usize,
1525    /// Current row index
1526    row_index: usize,
1527}
1528
1529impl BatchSplitter {
1530    /// Creates a new `BatchSplitter` with the specified batch size.
1531    pub(crate) fn new(batch_size: usize) -> Self {
1532        Self {
1533            batch: None,
1534            batch_size,
1535            row_index: 0,
1536        }
1537    }
1538}
1539
1540impl BatchTransformer for BatchSplitter {
1541    fn set_batch(&mut self, batch: RecordBatch) {
1542        self.batch = Some(batch);
1543        self.row_index = 0;
1544    }
1545
1546    fn next(&mut self) -> Option<(RecordBatch, bool)> {
1547        let Some(batch) = &self.batch else {
1548            return None;
1549        };
1550
1551        let remaining_rows = batch.num_rows() - self.row_index;
1552        let rows_to_slice = remaining_rows.min(self.batch_size);
1553        let sliced_batch = batch.slice(self.row_index, rows_to_slice);
1554        self.row_index += rows_to_slice;
1555
1556        let mut last = false;
1557        if self.row_index >= batch.num_rows() {
1558            self.batch = None;
1559            last = true;
1560        }
1561
1562        Some((sliced_batch, last))
1563    }
1564}
1565
1566/// When the order of the join inputs is changed, the output order of columns
1567/// must remain the same.
1568///
1569/// Joins output columns from their left input followed by their right input.
1570/// Thus if the inputs are reordered, the output columns must be reordered to
1571/// match the original order.
1572pub fn reorder_output_after_swap(
1573    plan: Arc<dyn ExecutionPlan>,
1574    left_schema: &Schema,
1575    right_schema: &Schema,
1576) -> Result<Arc<dyn ExecutionPlan>> {
1577    let proj = ProjectionExec::try_new(
1578        swap_reverting_projection(left_schema, right_schema),
1579        plan,
1580    )?;
1581    Ok(Arc::new(proj))
1582}
1583
1584/// When the order of the join is changed, the output order of columns must
1585/// remain the same.
1586///
1587/// Returns the expressions that will restore the values from the original
1588/// left input as the first columns, followed by those from the right input.
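///
/// For example, with a left schema `(a, b)` and a right schema `(c)`, the
/// swapped join outputs columns in the order `(c, a, b)`; this projection maps
/// them back to `(a, b, c)` by selecting indices 1, 2 and 0 respectively (see
/// `test_swap_reverting_projection` below).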
1589fn swap_reverting_projection(
1590    left_schema: &Schema,
1591    right_schema: &Schema,
1592) -> Vec<ProjectionExpr> {
1593    let right_cols =
1594        right_schema
1595            .fields()
1596            .iter()
1597            .enumerate()
1598            .map(|(i, f)| ProjectionExpr {
1599                expr: Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
1600                alias: f.name().to_owned(),
1601            });
1602    let right_len = right_cols.len();
1603    let left_cols =
1604        left_schema
1605            .fields()
1606            .iter()
1607            .enumerate()
1608            .map(|(i, f)| ProjectionExpr {
1609                expr: Arc::new(Column::new(f.name(), right_len + i))
1610                    as Arc<dyn PhysicalExpr>,
1611                alias: f.name().to_owned(),
1612            });
1613
1614    left_cols.chain(right_cols).collect()
1615}
1616
1617/// This function swaps the given join's projection.
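///
/// For example, with `left_schema_len = 2` and `right_schema_len = 3`, a
/// projection `[0, 3]` over the original join becomes `[3, 1]` over the
/// swapped join (for join types other than semi/anti, whose projections are
/// left untouched).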
1618pub fn swap_join_projection(
1619    left_schema_len: usize,
1620    right_schema_len: usize,
1621    projection: Option<&Vec<usize>>,
1622    join_type: &JoinType,
1623) -> Option<Vec<usize>> {
1624    match join_type {
1625        // For Anti/Semi join types, the projection should remain unmodified,
1626        // since these joins' output schema remains the same after the swap.
1627        JoinType::LeftAnti
1628        | JoinType::LeftSemi
1629        | JoinType::RightAnti
1630        | JoinType::RightSemi => projection.cloned(),
1631
1632        _ => projection.map(|p| {
1633            p.iter()
1634                .map(|i| {
1635                    // If the index is less than the left schema length, it is from
1636                    // the left schema, so we add the right schema length to it.
1637                    // Otherwise, it is from the right schema, so we subtract the left
1638                    // schema length from it.
1639                    if *i < left_schema_len {
1640                        *i + right_schema_len
1641                    } else {
1642                        *i - left_schema_len
1643                    }
1644                })
1645                .collect()
1646        }),
1647    }
1648}
1649
1650/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on`
1651/// using `offset` as a start value for `batch` row indices.
1652///
1653/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating the hashmap,
1654/// which keeps either the first (if set to true) or the last (if set to false) row index
1655/// as the chain head for rows with equal hash values.
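///
/// # Example (sketch)
///
/// Illustrative call shape only; `on_exprs`, `batch`, and `hash_map` are
/// assumed to exist in the caller, and the hashes buffer is assumed to be
/// sized to the number of rows before the call:
///
/// ```ignore
/// let mut hashes_buffer = vec![0u64; batch.num_rows()];
/// update_hash(
///     &on_exprs,
///     &batch,
///     &mut hash_map,
///     0,                    // offset: row index of the first row in `batch`
///     &RandomState::new(),
///     &mut hashes_buffer,
///     0,                    // deleted_offset
///     true,                 // fifo_hashmap: keep the first row index as chain head
/// )?;
/// ```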
1656#[allow(clippy::too_many_arguments)]
1657pub fn update_hash(
1658    on: &[PhysicalExprRef],
1659    batch: &RecordBatch,
1660    hash_map: &mut dyn JoinHashMapType,
1661    offset: usize,
1662    random_state: &RandomState,
1663    hashes_buffer: &mut Vec<u64>,
1664    deleted_offset: usize,
1665    fifo_hashmap: bool,
1666) -> Result<()> {
1667    // evaluate the keys
1668    let keys_values = on
1669        .iter()
1670        .map(|c| c.evaluate(batch)?.into_array(batch.num_rows()))
1671        .collect::<Result<Vec<_>>>()?;
1672
1673    // calculate the hash values
1674    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1675
1676    // For the usual JoinHashMap, this is a no-op.
1677    hash_map.extend_zero(batch.num_rows());
1678
1679    // Updating JoinHashMap from hash values iterator
1680    let hash_values_iter = hash_values
1681        .iter()
1682        .enumerate()
1683        .map(|(i, val)| (i + offset, val));
1684
1685    if fifo_hashmap {
1686        hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset);
1687    } else {
1688        hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset);
1689    }
1690
1691    Ok(())
1692}
1693
1694pub(super) fn equal_rows_arr(
1695    indices_left: &UInt64Array,
1696    indices_right: &UInt32Array,
1697    left_arrays: &[ArrayRef],
1698    right_arrays: &[ArrayRef],
1699    null_equality: NullEquality,
1700) -> Result<(UInt64Array, UInt32Array)> {
1701    let mut iter = left_arrays.iter().zip(right_arrays.iter());
1702
1703    let Some((first_left, first_right)) = iter.next() else {
1704        return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
1705    };
1706
1707    let arr_left = take(first_left.as_ref(), indices_left, None)?;
1708    let arr_right = take(first_right.as_ref(), indices_right, None)?;
1709
1710    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;
1711
1712    // Use map and try_fold to iterate over the remaining pairs of arrays.
1713    // In each iteration, take is used on the pair of arrays and their equality is determined.
1714    // The results are then folded (combined) using the and function to get a final equality result.
1715    equal = iter
1716        .map(|(left, right)| {
1717            let arr_left = take(left.as_ref(), indices_left, None)?;
1718            let arr_right = take(right.as_ref(), indices_right, None)?;
1719            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
1720        })
1721        .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
1722
1723    let filter_builder = FilterBuilder::new(&equal).optimize().build();
1724
1725    let left_filtered = filter_builder.filter(indices_left)?;
1726    let right_filtered = filter_builder.filter(indices_right)?;
1727
1728    Ok((
1729        downcast_array(left_filtered.as_ref()),
1730        downcast_array(right_filtered.as_ref()),
1731    ))
1732}
1733
1734// Version of the `eq` kernel that can optionally treat nulls as equal, per `null_equality`.
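// For example (sketch): comparing `[1, NULL]` with `[1, NULL]` yields `[true, true]`
// under `NullEqualsNull` (via `not_distinct`) and `[true, NULL]` under
// `NullEqualsNothing` (via `eq`).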
1735fn eq_dyn_null(
1736    left: &dyn Array,
1737    right: &dyn Array,
1738    null_equality: NullEquality,
1739) -> Result<BooleanArray, ArrowError> {
1740    // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
1741    // implementation
1742    // <https://github.com/apache/datafusion/issues/10749>
1743    if left.data_type().is_nested() {
1744        let op = match null_equality {
1745            NullEquality::NullEqualsNothing => Operator::Eq,
1746            NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
1747        };
1748        return Ok(compare_op_for_nested(op, &left, &right)?);
1749    }
1750    match null_equality {
1751        NullEquality::NullEqualsNothing => eq(&left, &right),
1752        NullEquality::NullEqualsNull => not_distinct(&left, &right),
1753    }
1754}
1755
1756#[cfg(test)]
1757mod tests {
1758    use std::collections::HashMap;
1759    use std::pin::Pin;
1760
1761    use super::*;
1762
1763    use arrow::array::Int32Array;
1764    use arrow::datatypes::{DataType, Fields};
1765    use arrow::error::{ArrowError, Result as ArrowResult};
1766    use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
1767    use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
1768    use datafusion_physical_expr::PhysicalSortExpr;
1769
1770    use rstest::rstest;
1771
1772    fn check(
1773        left: &[Column],
1774        right: &[Column],
1775        on: &[(PhysicalExprRef, PhysicalExprRef)],
1776    ) -> Result<()> {
1777        let left = left
1778            .iter()
1779            .map(|x| x.to_owned())
1780            .collect::<HashSet<Column>>();
1781        let right = right
1782            .iter()
1783            .map(|x| x.to_owned())
1784            .collect::<HashSet<Column>>();
1785        check_join_set_is_valid(&left, &right, on)
1786    }
1787
1788    #[test]
1789    fn check_valid() -> Result<()> {
1790        let left = vec![Column::new("a", 0), Column::new("b1", 1)];
1791        let right = vec![Column::new("a", 0), Column::new("b2", 1)];
1792        let on = &[(
1793            Arc::new(Column::new("a", 0)) as _,
1794            Arc::new(Column::new("a", 0)) as _,
1795        )];
1796
1797        check(&left, &right, on)?;
1798        Ok(())
1799    }
1800
1801    #[test]
1802    fn check_not_in_right() {
1803        let left = vec![Column::new("a", 0), Column::new("b", 1)];
1804        let right = vec![Column::new("b", 0)];
1805        let on = &[(
1806            Arc::new(Column::new("a", 0)) as _,
1807            Arc::new(Column::new("a", 0)) as _,
1808        )];
1809
1810        assert!(check(&left, &right, on).is_err());
1811    }
1812
1813    #[tokio::test]
1814    async fn check_error_nesting() {
1815        let once_fut = OnceFut::<()>::new(async {
1816            arrow_err!(ArrowError::CsvError("some error".to_string()))
1817        });
1818
1819        struct TestFut(OnceFut<()>);
1820        impl Future for TestFut {
1821            type Output = ArrowResult<()>;
1822
1823            fn poll(
1824                mut self: Pin<&mut Self>,
1825                cx: &mut Context<'_>,
1826            ) -> Poll<Self::Output> {
1827                match ready!(self.0.get(cx)) {
1828                    Ok(()) => Poll::Ready(Ok(())),
1829                    Err(e) => Poll::Ready(Err(e.into())),
1830                }
1831            }
1832        }
1833
1834        let res = TestFut(once_fut).await;
1835        let arrow_err_from_fut = res.expect_err("once_fut always return error");
1836
1837        let wrapped_err = DataFusionError::from(arrow_err_from_fut);
1838        let root_err = wrapped_err.find_root();
1839
1840        let _expected =
1841            arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned()));
1842
1843        assert!(matches!(root_err, _expected))
1844    }
1845
1846    #[test]
1847    fn check_not_in_left() {
1848        let left = vec![Column::new("b", 0)];
1849        let right = vec![Column::new("a", 0)];
1850        let on = &[(
1851            Arc::new(Column::new("a", 0)) as _,
1852            Arc::new(Column::new("a", 0)) as _,
1853        )];
1854
1855        assert!(check(&left, &right, on).is_err());
1856    }
1857
1858    #[test]
1859    fn check_collision() {
1860        // column "a" would appear both in left and right
1861        let left = vec![Column::new("a", 0), Column::new("c", 1)];
1862        let right = vec![Column::new("a", 0), Column::new("b", 1)];
1863        let on = &[(
1864            Arc::new(Column::new("a", 0)) as _,
1865            Arc::new(Column::new("b", 1)) as _,
1866        )];
1867
1868        assert!(check(&left, &right, on).is_ok());
1869    }
1870
1871    #[test]
1872    fn check_in_right() {
1873        let left = vec![Column::new("a", 0), Column::new("c", 1)];
1874        let right = vec![Column::new("b", 0)];
1875        let on = &[(
1876            Arc::new(Column::new("a", 0)) as _,
1877            Arc::new(Column::new("b", 0)) as _,
1878        )];
1879
1880        assert!(check(&left, &right, on).is_ok());
1881    }
1882
1883    #[test]
1884    fn test_join_schema() -> Result<()> {
1885        let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1886        let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1887        let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
1888        let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
1889
1890        let cases = vec![
1891            (&a, &b, JoinType::Inner, &a, &b),
1892            (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
1893            (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
1894            (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
1895            // right input of a `LEFT` join can be null, regardless of input nullness
1896            (&a, &b, JoinType::Left, &a, &b_nulls),
1897            (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
1898            (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
1899            (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
1900            // left input of a `RIGHT` join can be null, regardless of input nullness
1901            (&a, &b, JoinType::Right, &a_nulls, &b),
1902            (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
1903            (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
1904            (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
1905            // Either input of a `FULL` join can be null
1906            (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
1907            (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
1908            (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
1909            (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
1910        ];
1911
1912        for (left_in, right_in, join_type, left_out, right_out) in cases {
1913            let (schema, _) = build_join_schema(left_in, right_in, &join_type);
1914
1915            let expected_fields = left_out
1916                .fields()
1917                .iter()
1918                .cloned()
1919                .chain(right_out.fields().iter().cloned())
1920                .collect::<Fields>();
1921
1922            let expected_schema = Schema::new(expected_fields);
1923            assert_eq!(
1924                schema,
1925                expected_schema,
1926                "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
1927                left_in.fields()[0].name(),
1928                left_in.fields()[0].is_nullable(),
1929                right_in.fields()[0].name(),
1930                right_in.fields()[0].is_nullable(),
1931                join_type
1932            );
1933        }
1934
1935        Ok(())
1936    }
1937
1938    fn create_stats(
1939        num_rows: Option<usize>,
1940        column_stats: Vec<ColumnStatistics>,
1941        is_exact: bool,
1942    ) -> Statistics {
1943        Statistics {
1944            num_rows: if is_exact {
1945                num_rows.map(Exact)
1946            } else {
1947                num_rows.map(Inexact)
1948            }
1949            .unwrap_or(Absent),
1950            column_statistics: column_stats,
1951            total_byte_size: Absent,
1952        }
1953    }
1954
1955    fn create_column_stats(
1956        min: Precision<i64>,
1957        max: Precision<i64>,
1958        distinct_count: Precision<usize>,
1959        null_count: Precision<usize>,
1960    ) -> ColumnStatistics {
1961        ColumnStatistics {
1962            distinct_count,
1963            min_value: min.map(ScalarValue::from),
1964            max_value: max.map(ScalarValue::from),
1965            sum_value: Absent,
1966            null_count,
1967        }
1968    }
1969
1970    type PartialStats = (
1971        usize,
1972        Precision<i64>,
1973        Precision<i64>,
1974        Precision<usize>,
1975        Precision<usize>,
1976    );
1977
1978    // This is mainly for validating all the edge cases of the estimation, but
1979    // more advanced (and real-world) test cases are below, where we need some control
1980    // over the expected output (since it varies from join type to join type).
1981    #[test]
1982    fn test_inner_join_cardinality_single_column() -> Result<()> {
1983        let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
1984            // ------------------------------------------------
1985            // | left(rows, min, max, distinct, null_count),  |
1986            // | right(rows, min, max, distinct, null_count), |
1987            // | expected,                                    |
1988            // ------------------------------------------------
1989
1990            // Cardinality computation
1991            // =======================
1992            //
1993            // distinct(left) == NaN, distinct(right) == NaN
1994            (
1995                (10, Inexact(1), Inexact(10), Absent, Absent),
1996                (10, Inexact(1), Inexact(10), Absent, Absent),
1997                Some(Inexact(10)),
1998            ),
1999            // range(left) > range(right)
2000            (
2001                (10, Inexact(6), Inexact(10), Absent, Absent),
2002                (10, Inexact(8), Inexact(10), Absent, Absent),
2003                Some(Inexact(20)),
2004            ),
2005            // range(right) > range(left)
2006            (
2007                (10, Inexact(8), Inexact(10), Absent, Absent),
2008                (10, Inexact(6), Inexact(10), Absent, Absent),
2009                Some(Inexact(20)),
2010            ),
2011            // range(left) > len(left), range(right) > len(right)
2012            (
2013                (10, Inexact(1), Inexact(15), Absent, Absent),
2014                (20, Inexact(1), Inexact(40), Absent, Absent),
2015                Some(Inexact(10)),
2016            ),
2017            // When we have distinct count.
2018            (
2019                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
2020                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
2021                Some(Inexact(10)),
2022            ),
2023            // distinct(left) > distinct(right)
2024            (
2025                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
2026                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
2027                Some(Inexact(20)),
2028            ),
2029            // distinct(right) > distinct(left)
2030            (
2031                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
2032                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
2033                Some(Inexact(20)),
2034            ),
2035            // min(left) < 0 (range(left) > range(right))
2036            (
2037                (10, Inexact(-5), Inexact(5), Absent, Absent),
2038                (10, Inexact(1), Inexact(5), Absent, Absent),
2039                Some(Inexact(10)),
2040            ),
2041            // min(right) < 0, max(right) < 0 (range(right) > range(left))
2042            (
2043                (10, Inexact(-25), Inexact(-20), Absent, Absent),
2044                (10, Inexact(-25), Inexact(-15), Absent, Absent),
2045                Some(Inexact(10)),
2046            ),
2047            // range(left) < 0, range(right) >= 0
2048            // (there isn't a case where both left and right ranges are negative,
2049            //  so one of them is always going to work; this just proves that negative
2050            //  ranges with bigger absolute values are not accidentally used).
2051            (
2052                (10, Inexact(-10), Inexact(0), Absent, Absent),
2053                (10, Inexact(0), Inexact(10), Inexact(5), Absent),
2054                Some(Inexact(10)),
2055            ),
2056            // range(left) = 1, range(right) = 1
2057            (
2058                (10, Inexact(1), Inexact(1), Absent, Absent),
2059                (10, Inexact(1), Inexact(1), Absent, Absent),
2060                Some(Inexact(100)),
2061            ),
2062            //
2063            // Edge cases
2064            // ==========
2065            //
2066            // No column level stats.
2067            (
2068                (10, Absent, Absent, Absent, Absent),
2069                (10, Absent, Absent, Absent, Absent),
2070                None,
2071            ),
2072            // No min or max (or both).
2073            (
2074                (10, Absent, Absent, Inexact(3), Absent),
2075                (10, Absent, Absent, Inexact(3), Absent),
2076                None,
2077            ),
2078            (
2079                (10, Inexact(2), Absent, Inexact(3), Absent),
2080                (10, Absent, Inexact(5), Inexact(3), Absent),
2081                None,
2082            ),
2083            (
2084                (10, Absent, Inexact(3), Inexact(3), Absent),
2085                (10, Inexact(1), Absent, Inexact(3), Absent),
2086                None,
2087            ),
2088            (
2089                (10, Absent, Inexact(3), Absent, Absent),
2090                (10, Inexact(1), Absent, Absent, Absent),
2091                None,
2092            ),
2093            // Non overlapping min/max (when exact=False).
2094            (
2095                (10, Absent, Inexact(4), Absent, Absent),
2096                (10, Inexact(5), Absent, Absent, Absent),
2097                Some(Inexact(0)),
2098            ),
2099            (
2100                (10, Inexact(0), Inexact(10), Absent, Absent),
2101                (10, Inexact(11), Inexact(20), Absent, Absent),
2102                Some(Inexact(0)),
2103            ),
2104            (
2105                (10, Inexact(11), Inexact(20), Absent, Absent),
2106                (10, Inexact(0), Inexact(10), Absent, Absent),
2107                Some(Inexact(0)),
2108            ),
2109            // distinct(left) = 0, distinct(right) = 0
2110            (
2111                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
2112                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
2113                None,
2114            ),
2115            // Inexact row count < exact null count with absent distinct count
2116            (
2117                (0, Inexact(1), Inexact(10), Absent, Exact(5)),
2118                (10, Inexact(1), Inexact(10), Absent, Absent),
2119                Some(Inexact(0)),
2120            ),
2121        ];
2122
2123        for (left_info, right_info, expected_cardinality) in cases {
2124            let left_num_rows = left_info.0;
2125            let left_col_stats = vec![create_column_stats(
2126                left_info.1,
2127                left_info.2,
2128                left_info.3,
2129                left_info.4,
2130            )];
2131
2132            let right_num_rows = right_info.0;
2133            let right_col_stats = vec![create_column_stats(
2134                right_info.1,
2135                right_info.2,
2136                right_info.3,
2137                right_info.4,
2138            )];
2139
2140            assert_eq!(
2141                estimate_inner_join_cardinality(
2142                    Statistics {
2143                        num_rows: Inexact(left_num_rows),
2144                        total_byte_size: Absent,
2145                        column_statistics: left_col_stats.clone(),
2146                    },
2147                    Statistics {
2148                        num_rows: Inexact(right_num_rows),
2149                        total_byte_size: Absent,
2150                        column_statistics: right_col_stats.clone(),
2151                    },
2152                ),
2153                expected_cardinality.clone()
2154            );
2155
2156            // We should also be able to use join_cardinality to get the same results
2157            let join_type = JoinType::Inner;
2158            let join_on = vec![(
2159                Arc::new(Column::new("a", 0)) as _,
2160                Arc::new(Column::new("b", 0)) as _,
2161            )];
2162            let partial_join_stats = estimate_join_cardinality(
2163                &join_type,
2164                create_stats(Some(left_num_rows), left_col_stats.clone(), false),
2165                create_stats(Some(right_num_rows), right_col_stats.clone(), false),
2166                &join_on,
2167            );
2168
2169            assert_eq!(
2170                partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
2171                expected_cardinality.clone()
2172            );
2173            assert_eq!(
2174                partial_join_stats.map(|s| s.column_statistics),
2175                expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
2176            );
2177        }
2178        Ok(())
2179    }
2180
2181    #[test]
2182    fn test_inner_join_cardinality_multiple_column() -> Result<()> {
2183        let left_col_stats = vec![
2184            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2185            create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
2186        ];
2187
2188        let right_col_stats = vec![
2189            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2190            create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
2191        ];
2192
2193        // We have statistics about 4 columns, where the highest distinct
2194        // count is 200, so we are going to pick it.
2195        assert_eq!(
2196            estimate_inner_join_cardinality(
2197                Statistics {
2198                    num_rows: Inexact(400),
2199                    total_byte_size: Absent,
2200                    column_statistics: left_col_stats,
2201                },
2202                Statistics {
2203                    num_rows: Inexact(400),
2204                    total_byte_size: Absent,
2205                    column_statistics: right_col_stats,
2206                },
2207            ),
2208            Some(Inexact((400 * 400) / 200))
2209        );
2210        Ok(())
2211    }
2212
2213    #[test]
2214    fn test_inner_join_cardinality_decimal_range() -> Result<()> {
2215        let left_col_stats = vec![ColumnStatistics {
2216            distinct_count: Absent,
2217            min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)),
2218            max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)),
2219            ..Default::default()
2220        }];
2221
2222        let right_col_stats = vec![ColumnStatistics {
2223            distinct_count: Absent,
2224            min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)),
2225            max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)),
2226            ..Default::default()
2227        }];
2228
2229        assert_eq!(
2230            estimate_inner_join_cardinality(
2231                Statistics {
2232                    num_rows: Inexact(100),
2233                    total_byte_size: Absent,
2234                    column_statistics: left_col_stats,
2235                },
2236                Statistics {
2237                    num_rows: Inexact(100),
2238                    total_byte_size: Absent,
2239                    column_statistics: right_col_stats,
2240                },
2241            ),
2242            Some(Inexact(100))
2243        );
2244        Ok(())
2245    }
2246
2247    #[test]
2248    fn test_join_cardinality() -> Result<()> {
2249        // Left table (rows=1000)
2250        //   a: min=0, max=100, distinct=100
2251        //   b: min=0, max=500, distinct=500
2252        //   x: min=1000, max=10000, distinct=None
2253        //
2254        // Right table (rows=2000)
2255        //   c: min=0, max=100, distinct=50
2256        //   d: min=0, max=2000, distinct=2500 (how? some inexact statistics)
2257        //   y: min=0, max=100, distinct=None
2258        //
2259        // Join on a=c, b=d (ignore x/y)
2260        let cases = vec![
2261            (JoinType::Inner, 800),
2262            (JoinType::Left, 1000),
2263            (JoinType::Right, 2000),
2264            (JoinType::Full, 2200),
2265        ];
2266
2267        let left_col_stats = vec![
2268            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2269            create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2270            create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2271        ];
2272
2273        let right_col_stats = vec![
2274            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2275            create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2276            create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2277        ];
2278
2279        for (join_type, expected_num_rows) in cases {
2280            let join_on = vec![
2281                (
2282                    Arc::new(Column::new("a", 0)) as _,
2283                    Arc::new(Column::new("c", 0)) as _,
2284                ),
2285                (
2286                    Arc::new(Column::new("b", 1)) as _,
2287                    Arc::new(Column::new("d", 1)) as _,
2288                ),
2289            ];
2290
2291            let partial_join_stats = estimate_join_cardinality(
2292                &join_type,
2293                create_stats(Some(1000), left_col_stats.clone(), false),
2294                create_stats(Some(2000), right_col_stats.clone(), false),
2295                &join_on,
2296            )
2297            .unwrap();
2298            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2299            assert_eq!(
2300                partial_join_stats.column_statistics,
2301                [left_col_stats.clone(), right_col_stats.clone()].concat()
2302            );
2303        }
2304
2305        Ok(())
2306    }
2307
2308    #[test]
2309    fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
2310        // Left table (rows=1000)
2311        //   a: min=0, max=100, distinct=100
2312        //   b: min=0, max=500, distinct=500
2313        //   x: min=1000, max=10000, distinct=None
2314        //
2315        // Right table (rows=2000)
2316        //   c: min=0, max=100, distinct=50
2317        //   d: min=0, max=2000, distinct=2500 (how? some inexact statistics)
2318        //   y: min=0, max=100, distinct=None
2319        //
2320        // Join on a=c, x=y (ignores b/d) where x and y do not intersect
2321
2322        let left_col_stats = vec![
2323            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2324            create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2325            create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2326        ];
2327
2328        let right_col_stats = vec![
2329            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2330            create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2331            create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2332        ];
2333
2334        let join_on = vec![
2335            (
2336                Arc::new(Column::new("a", 0)) as _,
2337                Arc::new(Column::new("c", 0)) as _,
2338            ),
2339            (
2340                Arc::new(Column::new("x", 2)) as _,
2341                Arc::new(Column::new("y", 2)) as _,
2342            ),
2343        ];
2344
2345        let cases = vec![
2346            // Join type, expected cardinality
2347            //
2348            // When an inner join is disjoint, that means it won't
2349            // produce any rows.
2350            (JoinType::Inner, 0),
2351            // But left/right outer joins will produce at least
2352            // the number of rows from the left/right side.
2353            (JoinType::Left, 1000),
2354            (JoinType::Right, 2000),
2355            // And a full outer join will produce at least the combination
2356            // of the rows above (minus the cardinality of the inner join, which
2357            // is 0).
2358            (JoinType::Full, 3000),
2359        ];
2360
2361        for (join_type, expected_num_rows) in cases {
2362            let partial_join_stats = estimate_join_cardinality(
2363                &join_type,
2364                create_stats(Some(1000), left_col_stats.clone(), true),
2365                create_stats(Some(2000), right_col_stats.clone(), true),
2366                &join_on,
2367            )
2368            .unwrap();
2369            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2370            assert_eq!(
2371                partial_join_stats.column_statistics,
2372                [left_col_stats.clone(), right_col_stats.clone()].concat()
2373            );
2374        }
2375
2376        Ok(())
2377    }
2378
2379    #[test]
2380    fn test_anti_semi_join_cardinality() -> Result<()> {
2381        let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
2382            // ------------------------------------------------
2383            // | join_type ,                                   |
2384            // | left(rows, min, max, distinct, null_count),  |
2385            // | right(rows, min, max, distinct, null_count), |
2386            // | expected,                                    |
2387            // ------------------------------------------------
2388
2389            // Cardinality computation
2390            // =======================
2391            (
2392                JoinType::LeftSemi,
2393                (50, Inexact(10), Inexact(20), Absent, Absent),
2394                (10, Inexact(15), Inexact(25), Absent, Absent),
2395                Some(50),
2396            ),
2397            (
2398                JoinType::RightSemi,
2399                (50, Inexact(10), Inexact(20), Absent, Absent),
2400                (10, Inexact(15), Inexact(25), Absent, Absent),
2401                Some(10),
2402            ),
2403            (
2404                JoinType::LeftSemi,
2405                (10, Absent, Absent, Absent, Absent),
2406                (50, Absent, Absent, Absent, Absent),
2407                Some(10),
2408            ),
2409            (
2410                JoinType::LeftSemi,
2411                (50, Inexact(10), Inexact(20), Absent, Absent),
2412                (10, Inexact(30), Inexact(40), Absent, Absent),
2413                Some(0),
2414            ),
2415            (
2416                JoinType::LeftSemi,
2417                (50, Inexact(10), Absent, Absent, Absent),
2418                (10, Absent, Inexact(5), Absent, Absent),
2419                Some(0),
2420            ),
2421            (
2422                JoinType::LeftSemi,
2423                (50, Absent, Inexact(20), Absent, Absent),
2424                (10, Inexact(30), Absent, Absent, Absent),
2425                Some(0),
2426            ),
2427            (
2428                JoinType::LeftAnti,
2429                (50, Inexact(10), Inexact(20), Absent, Absent),
2430                (10, Inexact(15), Inexact(25), Absent, Absent),
2431                Some(50),
2432            ),
2433            (
2434                JoinType::RightAnti,
2435                (50, Inexact(10), Inexact(20), Absent, Absent),
2436                (10, Inexact(15), Inexact(25), Absent, Absent),
2437                Some(10),
2438            ),
2439            (
2440                JoinType::LeftAnti,
2441                (10, Absent, Absent, Absent, Absent),
2442                (50, Absent, Absent, Absent, Absent),
2443                Some(10),
2444            ),
2445            (
2446                JoinType::LeftAnti,
2447                (50, Inexact(10), Inexact(20), Absent, Absent),
2448                (10, Inexact(30), Inexact(40), Absent, Absent),
2449                Some(50),
2450            ),
2451            (
2452                JoinType::LeftAnti,
2453                (50, Inexact(10), Absent, Absent, Absent),
2454                (10, Absent, Inexact(5), Absent, Absent),
2455                Some(50),
2456            ),
2457            (
2458                JoinType::LeftAnti,
2459                (50, Absent, Inexact(20), Absent, Absent),
2460                (10, Inexact(30), Absent, Absent, Absent),
2461                Some(50),
2462            ),
2463        ];
2464
2465        let join_on = vec![(
2466            Arc::new(Column::new("l_col", 0)) as _,
2467            Arc::new(Column::new("r_col", 0)) as _,
2468        )];
2469
2470        for (join_type, outer_info, inner_info, expected) in cases {
2471            let outer_num_rows = outer_info.0;
2472            let outer_col_stats = vec![create_column_stats(
2473                outer_info.1,
2474                outer_info.2,
2475                outer_info.3,
2476                outer_info.4,
2477            )];
2478
2479            let inner_num_rows = inner_info.0;
2480            let inner_col_stats = vec![create_column_stats(
2481                inner_info.1,
2482                inner_info.2,
2483                inner_info.3,
2484                inner_info.4,
2485            )];
2486
2487            let output_cardinality = estimate_join_cardinality(
2488                &join_type,
2489                Statistics {
2490                    num_rows: Inexact(outer_num_rows),
2491                    total_byte_size: Absent,
2492                    column_statistics: outer_col_stats,
2493                },
2494                Statistics {
2495                    num_rows: Inexact(inner_num_rows),
2496                    total_byte_size: Absent,
2497                    column_statistics: inner_col_stats,
2498                },
2499                &join_on,
2500            )
2501            .map(|cardinality| cardinality.num_rows);
2502
2503            assert_eq!(
2504                output_cardinality, expected,
2505                "failure for join_type: {join_type}"
2506            );
2507        }
2508
2509        Ok(())
2510    }
2511
2512    #[test]
2513    fn test_semi_join_cardinality_absent_rows() -> Result<()> {
2514        let dummy_column_stats =
2515            vec![create_column_stats(Absent, Absent, Absent, Absent)];
2516        let join_on = vec![(
2517            Arc::new(Column::new("l_col", 0)) as _,
2518            Arc::new(Column::new("r_col", 0)) as _,
2519        )];
2520
2521        let absent_outer_estimation = estimate_join_cardinality(
2522            &JoinType::LeftSemi,
2523            Statistics {
2524                num_rows: Absent,
2525                total_byte_size: Absent,
2526                column_statistics: dummy_column_stats.clone(),
2527            },
2528            Statistics {
2529                num_rows: Exact(10),
2530                total_byte_size: Absent,
2531                column_statistics: dummy_column_stats.clone(),
2532            },
2533            &join_on,
2534        );
2535        assert!(
2536            absent_outer_estimation.is_none(),
2537            "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
2538        );
2539
2540        let absent_inner_estimation = estimate_join_cardinality(
2541            &JoinType::LeftSemi,
2542            Statistics {
2543                num_rows: Inexact(500),
2544                total_byte_size: Absent,
2545                column_statistics: dummy_column_stats.clone(),
2546            },
2547            Statistics {
2548                num_rows: Absent,
2549                total_byte_size: Absent,
2550                column_statistics: dummy_column_stats.clone(),
2551            },
2552            &join_on,
2553        ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows");
2554
2555        assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows");
2556
2557        let absent_inner_estimation = estimate_join_cardinality(
2558            &JoinType::LeftSemi,
2559            Statistics {
2560                num_rows: Absent,
2561                total_byte_size: Absent,
2562                column_statistics: dummy_column_stats.clone(),
2563            },
2564            Statistics {
2565                num_rows: Absent,
2566                total_byte_size: Absent,
2567                column_statistics: dummy_column_stats,
2568            },
2569            &join_on,
2570        );
2571        assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows");
2572
2573        Ok(())
2574    }
2575
2576    #[test]
2577    fn test_calculate_join_output_ordering() -> Result<()> {
2578        let left_ordering = LexOrdering::new(vec![
2579            PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
2580            PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
2581            PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
2582        ]);
2583        let right_ordering = LexOrdering::new(vec![
2584            PhysicalSortExpr::new_default(Arc::new(Column::new("z", 2))),
2585            PhysicalSortExpr::new_default(Arc::new(Column::new("y", 1))),
2586        ]);
2587        let join_type = JoinType::Inner;
2588        let left_columns_len = 5;
2589        let maintains_input_orders = [[true, false], [false, true]];
2590        let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];
2591
2592        let expected = [
2593            LexOrdering::new(vec![
2594                PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
2595                PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
2596                PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
2597                PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))),
2598                PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))),
2599            ]),
2600            LexOrdering::new(vec![
2601                PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))),
2602                PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))),
2603                PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
2604                PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
2605                PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
2606            ]),
2607        ];
2608
2609        for (i, (maintains_input_order, probe_side)) in
2610            maintains_input_orders.iter().zip(probe_sides).enumerate()
2611        {
2612            assert_eq!(
2613                calculate_join_output_ordering(
2614                    left_ordering.as_ref(),
2615                    right_ordering.as_ref(),
2616                    join_type,
2617                    left_columns_len,
2618                    maintains_input_order,
2619                    probe_side,
2620                )?,
2621                expected[i]
2622            );
2623        }
2624
2625        Ok(())
2626    }
2627
2628    fn create_test_batch(num_rows: usize) -> RecordBatch {
2629        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2630        let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
2631        RecordBatch::try_new(schema, vec![data]).unwrap()
2632    }
2633
2634    fn assert_split_batches(
2635        batches: Vec<(RecordBatch, bool)>,
2636        batch_size: usize,
2637        num_rows: usize,
2638    ) {
2639        let mut row_count = 0;
2640        for (batch, last) in batches.into_iter() {
2641            assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size));
2642            let column = batch
2643                .column(0)
2644                .as_any()
2645                .downcast_ref::<Int32Array>()
2646                .unwrap();
2647            for i in 0..batch.num_rows() {
2648                assert_eq!(column.value(i), i as i32 + row_count as i32);
2649            }
2650            row_count += batch.num_rows();
2651            assert_eq!(last, row_count == num_rows);
2652        }
2653    }
2654
2655    #[rstest]
2656    #[test]
2657    fn test_batch_splitter(
2658        #[values(1, 3, 11)] batch_size: usize,
2659        #[values(1, 6, 50)] num_rows: usize,
2660    ) {
2661        let mut splitter = BatchSplitter::new(batch_size);
2662        splitter.set_batch(create_test_batch(num_rows));
2663
2664        let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size));
2665        while let Some(batch) = splitter.next() {
2666            batches.push(batch);
2667        }
2668
2669        assert!(splitter.next().is_none());
2670        assert_split_batches(batches, batch_size, num_rows);
2671    }
2672
2673    #[tokio::test]
2674    async fn test_swap_reverting_projection() {
2675        let left_schema = Schema::new(vec![
2676            Field::new("a", DataType::Int32, false),
2677            Field::new("b", DataType::Int32, false),
2678        ]);
2679
2680        let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);
2681
2682        let proj = swap_reverting_projection(&left_schema, &right_schema);
2683
2684        assert_eq!(proj.len(), 3);
2685
2686        let proj_expr = &proj[0];
2687        assert_eq!(proj_expr.alias, "a");
2688        assert_col_expr(&proj_expr.expr, "a", 1);
2689
2690        let proj_expr = &proj[1];
2691        assert_eq!(proj_expr.alias, "b");
2692        assert_col_expr(&proj_expr.expr, "b", 2);
2693
2694        let proj_expr = &proj[2];
2695        assert_eq!(proj_expr.alias, "c");
2696        assert_col_expr(&proj_expr.expr, "c", 0);
2697    }
2698
2699    fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
2700        let col = expr
2701            .as_any()
2702            .downcast_ref::<Column>()
2703            .expect("Projection items should be Column expression");
2704        assert_eq!(col.name(), name);
2705        assert_eq!(col.index(), index);
2706    }
2707
2708    #[test]
2709    fn test_join_metadata() -> Result<()> {
2710        let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)])
2711            .with_metadata(HashMap::from([("key".to_string(), "left".to_string())]));
2712
2713        let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)])
2714            .with_metadata(HashMap::from([("key".to_string(), "right".to_string())]));
2715
2716        let (join_schema, _) =
2717            build_join_schema(&left_schema, &right_schema, &JoinType::Left);
2718        assert_eq!(
2719            join_schema.metadata(),
2720            &HashMap::from([("key".to_string(), "left".to_string())])
2721        );
2722        let (join_schema, _) =
2723            build_join_schema(&left_schema, &right_schema, &JoinType::Right);
2724        assert_eq!(
2725            join_schema.metadata(),
2726            &HashMap::from([("key".to_string(), "right".to_string())])
2727        );
2728
2729        Ok(())
2730    }
2731}