use std::cmp::{Ordering, min};
use std::collections::HashSet;
use std::fmt::{self, Debug};
use std::future::Future;
use std::iter::once;
use std::ops::Range;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::joins::SharedBitmapBuilder;
use crate::metrics::{
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricType,
};
use crate::projection::{ProjectionExec, ProjectionExpr};
use crate::{
ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
};
pub use super::join_filter::JoinFilter;
pub use super::join_hash_map::JoinHashMapType;
pub use crate::joins::{JoinOn, JoinOnRef};
use ahash::RandomState;
use arrow::array::{
Array, ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray,
RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array,
builder::UInt64Builder, downcast_array, new_null_array,
};
use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int8Array,
Int16Array, Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, StringArray,
StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::kernels::cmp::eq;
use arrow::compute::{self, FilterBuilder, and, take};
use arrow::datatypes::{
ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
};
use arrow_ord::cmp::not_distinct;
use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::stats::Precision;
use datafusion_common::{
DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult,
not_impl_err, plan_err,
};
use datafusion_expr::Operator;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{
LexOrdering, PhysicalExpr, PhysicalExprRef, add_offset_to_expr,
add_offset_to_physical_sort_exprs,
};
use datafusion_physical_expr_common::datum::compare_op_for_nested;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use futures::future::{BoxFuture, Shared};
use futures::{FutureExt, ready};
use parking_lot::Mutex;
/// Checks that every column referenced by the equijoin predicates in `on`
/// exists in the corresponding input schema.
///
/// Builds the set of `(name, index)` columns for each schema and delegates
/// the actual membership check to [`check_join_set_is_valid`].
pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
    // Both sides are converted the same way, so share one closure.
    let columns_of = |schema: &Schema| -> HashSet<Column> {
        schema
            .fields()
            .iter()
            .enumerate()
            .map(|(idx, field)| Column::new(field.name(), idx))
            .collect()
    };
    check_join_set_is_valid(&columns_of(left), &columns_of(right), on)
}
/// Checks that the column sets of the `left` and `right` inputs cover every
/// column referenced by the equijoin predicates in `on`.
///
/// Returns a plan error listing, per side, the referenced columns that are
/// missing from that side's input.
fn check_join_set_is_valid(
    left: &HashSet<Column>,
    right: &HashSet<Column>,
    on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> Result<()> {
    // Columns referenced by the left-hand expressions of the predicates.
    let on_left = &on
        .iter()
        .flat_map(|on| collect_columns(&on.0))
        .collect::<HashSet<_>>();
    let left_missing = on_left.difference(left).collect::<HashSet<_>>();
    // Columns referenced by the right-hand expressions of the predicates.
    let on_right = &on
        .iter()
        .flat_map(|on| collect_columns(&on.1))
        .collect::<HashSet<_>>();
    let right_missing = on_right.difference(right).collect::<HashSet<_>>();
    // Idiom fix: use short-circuiting `||` instead of bitwise `|` on bools.
    // Both operands are already computed, so behavior is unchanged.
    if !left_missing.is_empty() || !right_missing.is_empty() {
        return plan_err!(
            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}"
        );
    };
    Ok(())
}
/// Rewrites the right child's output partitioning so its column references
/// are valid in the join's combined output schema (left columns first).
///
/// Only `Partitioning::Hash` carries expressions that need shifting by
/// `left_columns_len`; every other variant is returned unchanged.
pub fn adjust_right_output_partitioning(
    right_partitioning: &Partitioning,
    left_columns_len: usize,
) -> Result<Partitioning> {
    if let Partitioning::Hash(exprs, size) = right_partitioning {
        let shifted = exprs
            .iter()
            .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _))
            .collect::<Result<_>>()?;
        Ok(Partitioning::Hash(shifted, *size))
    } else {
        Ok(right_partitioning.clone())
    }
}
/// Calculates the output ordering of a join given the orderings of its
/// children, which child (if any) maintains its input order, and the probe
/// side.
///
/// Right-side sort expressions are shifted by `left_columns_len` before being
/// emitted, since the join output places left columns first.
/// Returns `Ok(None)` when no ordering can be guaranteed.
pub fn calculate_join_output_ordering(
    left_ordering: Option<&LexOrdering>,
    right_ordering: Option<&LexOrdering>,
    join_type: JoinType,
    left_columns_len: usize,
    maintains_input_order: &[bool],
    probe_side: Option<JoinSide>,
) -> Result<Option<LexOrdering>> {
    match maintains_input_order {
        // Left input order is maintained.
        [true, false] => {
            // Special case: for an inner join probing the left side, the right
            // ordering remains valid within each left group, so the combined
            // ordering [left..., right...] holds.
            if join_type == JoinType::Inner
                && probe_side == Some(JoinSide::Left)
                && let Some(right_ordering) = right_ordering.cloned()
            {
                let right_offset = add_offset_to_physical_sort_exprs(
                    right_ordering,
                    left_columns_len as _,
                )?;
                return if let Some(left_ordering) = left_ordering {
                    let mut result = left_ordering.clone();
                    result.extend(right_offset);
                    Ok(Some(result))
                } else {
                    Ok(LexOrdering::new(right_offset))
                };
            }
            Ok(left_ordering.cloned())
        }
        // Right input order is maintained.
        [false, true] => {
            // Mirror of the case above: inner join probing the right side can
            // append the left ordering after the (shifted) right ordering.
            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) {
                return if let Some(right_ordering) = right_ordering.cloned() {
                    let mut right_offset = add_offset_to_physical_sort_exprs(
                        right_ordering,
                        left_columns_len as _,
                    )?;
                    if let Some(left_ordering) = left_ordering {
                        right_offset.extend(left_ordering.clone());
                    }
                    Ok(LexOrdering::new(right_offset))
                } else {
                    Ok(left_ordering.cloned())
                };
            }
            let Some(right_ordering) = right_ordering else {
                return Ok(None);
            };
            match join_type {
                // These join types emit right columns after the left ones, so
                // the right ordering must be shifted by the left width.
                JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
                    add_offset_to_physical_sort_exprs(
                        right_ordering.clone(),
                        left_columns_len as _,
                    )
                    .map(LexOrdering::new)
                }
                // Semi/anti/mark joins emit right columns unshifted.
                _ => Ok(Some(right_ordering.clone())),
            }
        }
        // Neither side's order survives the join.
        [false, false] => Ok(None),
        [true, true] => unreachable!("Cannot maintain ordering of both sides"),
        _ => unreachable!("Join operators can not have more than two children"),
    }
}
/// Maps an output column of a join back to the input it came from.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnIndex {
    /// Index of the column within its originating input schema.
    pub index: usize,
    /// Which join input (left/right/none) the column originates from.
    pub side: JoinSide,
}
/// Produces the output field for one input field of a join, forcing
/// nullability when the join type can emit nulls for that side:
/// the non-preserved side of an outer join, or either side of a full join.
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
    let force_nullable = match join_type {
        // Full outer joins can produce nulls on both sides.
        JoinType::Full => true,
        // Left outer join pads the right side with nulls, and vice versa.
        JoinType::Left => !is_left,
        JoinType::Right => is_left,
        // Inner/semi/anti/mark joins never introduce nulls in input columns.
        JoinType::Inner
        | JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti
        | JoinType::LeftMark
        | JoinType::RightMark => false,
    };
    if force_nullable {
        old_field.clone().with_nullable(true)
    } else {
        old_field.clone()
    }
}
/// Builds the output schema of a join along with the `ColumnIndex` mapping
/// that records, for each output column, which input it comes from.
///
/// Semi/anti joins emit only one side's columns; mark joins append a
/// non-nullable boolean "mark" column with `JoinSide::None`.
pub fn build_join_schema(
    left: &Schema,
    right: &Schema,
    join_type: &JoinType,
) -> (Schema, Vec<ColumnIndex>) {
    // Lazy iterator builders so each side's fields are only materialized for
    // the join types that actually emit them.
    let left_fields = || {
        left.fields()
            .iter()
            .map(|f| output_join_field(f, join_type, true))
            .enumerate()
            .map(|(index, f)| {
                (
                    f,
                    ColumnIndex {
                        index,
                        side: JoinSide::Left,
                    },
                )
            })
    };
    let right_fields = || {
        right
            .fields()
            .iter()
            .map(|f| output_join_field(f, join_type, false))
            .enumerate()
            .map(|(index, f)| {
                (
                    f,
                    ColumnIndex {
                        index,
                        side: JoinSide::Right,
                    },
                )
            })
    };
    let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
        // Both sides, left columns first.
        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
            left_fields().chain(right_fields()).unzip()
        }
        // Only the preserved side's columns.
        JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
        JoinType::LeftMark => {
            // Synthetic boolean column indicating whether a match was found.
            let right_field = once((
                Field::new("mark", DataType::Boolean, false),
                ColumnIndex {
                    index: 0,
                    side: JoinSide::None,
                },
            ));
            left_fields().chain(right_field).unzip()
        }
        JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
        JoinType::RightMark => {
            let left_field = once((
                Field::new("mark", DataType::Boolean, false),
                ColumnIndex {
                    index: 0,
                    side: JoinSide::None,
                },
            ));
            right_fields().chain(left_field).unzip()
        }
    };
    // Merge schema-level metadata from both inputs; `schema2`'s entries are
    // chained last so they win on key collision.
    let (schema1, schema2) = match join_type {
        JoinType::Right
        | JoinType::RightSemi
        | JoinType::RightAnti
        | JoinType::RightMark => (left, right),
        _ => (right, left),
    };
    let metadata = schema1
        .metadata()
        .clone()
        .into_iter()
        .chain(schema2.metadata().clone())
        .collect();
    (fields.finish().with_metadata(metadata), column_indices)
}
/// Runs a fallible async initialization at most once, sharing the resulting
/// [`OnceFut`] (or the construction error) with every caller.
pub(crate) struct OnceAsync<T> {
    // `None` until the first `try_once` call; then caches either the shared
    // future or the error produced while constructing it.
    fut: Mutex<Option<SharedResult<OnceFut<T>>>>,
}
impl<T> Default for OnceAsync<T> {
    fn default() -> Self {
        Self {
            fut: Mutex::new(None),
        }
    }
}
impl<T> Debug for OnceAsync<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "OnceAsync")
    }
}
impl<T: 'static> OnceAsync<T> {
    /// Invokes `f` on the first call to create the future; later calls clone
    /// the same shared future (or re-report the original error).
    pub(crate) fn try_once<F, Fut>(&self, f: F) -> Result<OnceFut<T>>
    where
        F: FnOnce() -> Result<Fut>,
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        self.fut
            .lock()
            .get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new))
            .clone()
            .map_err(DataFusionError::Shared)
    }
}
/// A boxed, shareable future whose output is wrapped in `Arc`s so multiple
/// consumers can await and clone the same result.
type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;
/// A future that can be polled from several streams; the underlying
/// computation runs once and its result is shared among all clones.
pub(crate) struct OnceFut<T> {
    state: OnceFutState<T>,
}
impl<T> Clone for OnceFut<T> {
    // Manual impl: `#[derive(Clone)]` would require `T: Clone`, but clones
    // only share the state, never duplicate `T`.
    fn clone(&self) -> Self {
        Self {
            state: self.state.clone(),
        }
    }
}
/// Intermediate join statistics: estimated row count plus per-column
/// statistics, without a total byte size (hence "partial").
#[derive(Clone, Debug, Default)]
struct PartialJoinStatistics {
    pub num_rows: usize,
    pub column_statistics: Vec<ColumnStatistics>,
}
/// Estimates output [`Statistics`] for a join from the children's statistics
/// and the equijoin keys; falls back to unknown column statistics when no
/// cardinality estimate is possible.
pub(crate) fn estimate_join_statistics(
    left_stats: Statistics,
    right_stats: Statistics,
    on: &JoinOn,
    join_type: &JoinType,
    schema: &Schema,
) -> Result<Statistics> {
    let (num_rows, column_statistics) =
        match estimate_join_cardinality(join_type, left_stats, right_stats, on) {
            // Cardinality estimates are heuristic, so they are always inexact.
            Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics),
            None => (Precision::Absent, Statistics::unknown_column(schema)),
        };
    Ok(Statistics {
        num_rows,
        total_byte_size: Precision::Absent,
        column_statistics,
    })
}
/// Estimates the number of output rows of a join, per join type, using the
/// statistics of the equijoin key columns.
///
/// Returns `None` when the required inputs (row counts, key statistics) are
/// unavailable.
fn estimate_join_cardinality(
    join_type: &JoinType,
    left_stats: Statistics,
    right_stats: Statistics,
    on: &JoinOn,
) -> Option<PartialJoinStatistics> {
    // Collect statistics for each equijoin key pair; non-`Column` expressions
    // degrade to unknown statistics.
    let (left_col_stats, right_col_stats) = on
        .iter()
        .map(|(left, right)| {
            match (
                left.as_any().downcast_ref::<Column>(),
                right.as_any().downcast_ref::<Column>(),
            ) {
                (Some(left), Some(right)) => (
                    left_stats.column_statistics[left.index()].clone(),
                    right_stats.column_statistics[right.index()].clone(),
                ),
                _ => (
                    ColumnStatistics::new_unknown(),
                    ColumnStatistics::new_unknown(),
                ),
            }
        })
        .unzip::<_, _, Vec<_>, Vec<_>>();
    match join_type {
        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
            // Start from the inner-join estimate, then adjust for outer joins:
            // each outer side contributes at least its own row count, and a
            // full join is left + right - inner (inclusion/exclusion).
            let ij_cardinality = estimate_inner_join_cardinality(
                Statistics {
                    num_rows: left_stats.num_rows,
                    total_byte_size: Precision::Absent,
                    column_statistics: left_col_stats,
                },
                Statistics {
                    num_rows: right_stats.num_rows,
                    total_byte_size: Precision::Absent,
                    column_statistics: right_col_stats,
                },
            )?;
            let cardinality = match join_type {
                JoinType::Inner => ij_cardinality,
                JoinType::Left => ij_cardinality.max(&left_stats.num_rows),
                JoinType::Right => ij_cardinality.max(&right_stats.num_rows),
                JoinType::Full => ij_cardinality
                    .max(&left_stats.num_rows)
                    .add(&ij_cardinality.max(&right_stats.num_rows))
                    .sub(&ij_cardinality),
                _ => unreachable!(),
            };
            Some(PartialJoinStatistics {
                num_rows: *cardinality.get_value()?,
                // Inner/outer joins output all columns from both sides.
                column_statistics: left_stats
                    .column_statistics
                    .into_iter()
                    .chain(right_stats.column_statistics)
                    .collect(),
            })
        }
        JoinType::LeftSemi | JoinType::RightSemi => {
            let (outer_stats, inner_stats) = match join_type {
                JoinType::LeftSemi => (left_stats, right_stats),
                _ => (right_stats, left_stats),
            };
            // If the key ranges are disjoint the semi join emits nothing;
            // otherwise upper-bound by the outer side's row count.
            let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
                Some(estimation) => *estimation.get_value()?,
                None => *outer_stats.num_rows.get_value()?,
            };
            Some(PartialJoinStatistics {
                num_rows: cardinality,
                column_statistics: outer_stats.column_statistics,
            })
        }
        JoinType::LeftAnti | JoinType::RightAnti => {
            // Anti joins are upper-bounded by the outer side's row count.
            let outer_stats = match join_type {
                JoinType::LeftAnti => left_stats,
                _ => right_stats,
            };
            Some(PartialJoinStatistics {
                num_rows: *outer_stats.num_rows.get_value()?,
                column_statistics: outer_stats.column_statistics,
            })
        }
        JoinType::LeftMark => {
            // Mark joins emit exactly one row per outer row, plus the
            // synthetic boolean mark column (unknown statistics).
            let num_rows = *left_stats.num_rows.get_value()?;
            let mut column_statistics = left_stats.column_statistics;
            column_statistics.push(ColumnStatistics::new_unknown());
            Some(PartialJoinStatistics {
                num_rows,
                column_statistics,
            })
        }
        JoinType::RightMark => {
            let num_rows = *right_stats.num_rows.get_value()?;
            let mut column_statistics = right_stats.column_statistics;
            column_statistics.push(ColumnStatistics::new_unknown());
            Some(PartialJoinStatistics {
                num_rows,
                column_statistics,
            })
        }
    }
}
/// Estimates the cardinality of an inner join from the statistics of its key
/// columns (the `column_statistics` of both inputs here hold only the
/// on-columns, paired positionally).
///
/// Uses the classic formula |R ⋈ S| = |R|·|S| / max(V(R.k), V(S.k)).
/// Returns `None` when no usable selectivity estimate is available.
fn estimate_inner_join_cardinality(
    left_stats: Statistics,
    right_stats: Statistics,
) -> Option<Precision<usize>> {
    // Provably disjoint key ranges mean the join emits zero rows.
    if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) {
        return Some(estimation);
    };
    let Statistics {
        num_rows: left_num_rows,
        column_statistics: left_column_statistics,
        ..
    } = left_stats;
    let Statistics {
        num_rows: right_num_rows,
        column_statistics: right_column_statistics,
        ..
    } = right_stats;
    // Keep the distinct-count estimate of the last key pair for which one is
    // available; it acts as the selectivity denominator below.
    let mut join_selectivity = Precision::Absent;
    for (left_stat, right_stat) in left_column_statistics
        .iter()
        .zip(right_column_statistics.iter())
    {
        let left_max_distinct = max_distinct_count(&left_num_rows, left_stat);
        let right_max_distinct = max_distinct_count(&right_num_rows, right_stat);
        let max_distinct = left_max_distinct.max(&right_max_distinct);
        if max_distinct.get_value().is_some() {
            join_selectivity = max_distinct;
        }
    }
    // Bug fix: read the row counts from the values destructured above. The
    // previous code accessed `left_stats.num_rows` / `right_stats.num_rows`
    // here, which does not compile because both structs were partially moved
    // by the destructuring `let`s.
    let left_num_rows = left_num_rows.get_value()?;
    let right_num_rows = right_num_rows.get_value()?;
    match join_selectivity {
        Precision::Exact(value) if value > 0 => {
            Some(Precision::Exact((left_num_rows * right_num_rows) / value))
        }
        Precision::Inexact(value) if value > 0 => {
            Some(Precision::Inexact((left_num_rows * right_num_rows) / value))
        }
        // Zero or unknown distinct count: no meaningful estimate.
        _ => None,
    }
}
fn estimate_disjoint_inputs(
left_stats: &Statistics,
right_stats: &Statistics,
) -> Option<Precision<usize>> {
for (left_stat, right_stat) in left_stats
.column_statistics
.iter()
.zip(right_stats.column_statistics.iter())
{
let left_min_val = left_stat.min_value.get_value();
let right_max_val = right_stat.max_value.get_value();
if left_min_val.is_some()
&& right_max_val.is_some()
&& left_min_val > right_max_val
{
return Some(
if left_stat.min_value.is_exact().unwrap_or(false)
&& right_stat.max_value.is_exact().unwrap_or(false)
{
Precision::Exact(0)
} else {
Precision::Inexact(0)
},
);
}
let left_max_val = left_stat.max_value.get_value();
let right_min_val = right_stat.min_value.get_value();
if left_max_val.is_some()
&& right_min_val.is_some()
&& left_max_val < right_min_val
{
return Some(
if left_stat.max_value.is_exact().unwrap_or(false)
&& right_stat.min_value.is_exact().unwrap_or(false)
{
Precision::Exact(0)
} else {
Precision::Inexact(0)
},
);
}
}
None
}
/// Estimates an upper bound on a column's distinct count.
///
/// Prefers the recorded distinct count; otherwise uses the non-null row count,
/// tightened by the cardinality of the column's min/max value range when that
/// range is known and smaller.
fn max_distinct_count(
    num_rows: &Precision<usize>,
    stats: &ColumnStatistics,
) -> Precision<usize> {
    match &stats.distinct_count {
        // A recorded distinct count wins outright.
        &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc,
        _ => {
            // Fall back to (row count - null count): distinct values cannot
            // exceed the number of non-null rows.
            let result = match num_rows {
                Precision::Absent => Precision::Absent,
                Precision::Inexact(count) => {
                    // checked_sub: an inexact null count may exceed the
                    // inexact row count; clamp to zero instead of panicking.
                    match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
                        None => Precision::Inexact(0),
                        Some(non_null_count) => Precision::Inexact(non_null_count),
                    }
                }
                Precision::Exact(count) => {
                    let count = count - stats.null_count.get_value().unwrap_or(&0);
                    // The result is only exact if the null count was exact.
                    if stats.null_count.is_exact().unwrap_or(false) {
                        Precision::Exact(count)
                    } else {
                        Precision::Inexact(count)
                    }
                }
            };
            // When the [min, max] interval has a finite cardinality, it also
            // bounds the distinct count; take it if tighter than `result`.
            if let (Some(min), Some(max)) =
                (stats.min_value.get_value(), stats.max_value.get_value())
                && let Some(range_dc) = Interval::try_new(min.clone(), max.clone())
                    .ok()
                    .and_then(|e| e.cardinality())
            {
                let range_dc = range_dc as usize;
                // `result.get_value().unwrap()` is safe here: the `Absent`
                // case is handled by the left-hand side of `||`.
                return if result == Precision::Absent
                    || &range_dc < result.get_value().unwrap()
                {
                    // `is_exact().unwrap()` is safe: `get_value()` returned
                    // `Some` above, so neither bound is `Absent`.
                    if stats.min_value.is_exact().unwrap()
                        && stats.max_value.is_exact().unwrap()
                    {
                        Precision::Exact(range_dc)
                    } else {
                        Precision::Inexact(range_dc)
                    }
                } else {
                    result
                };
            }
            result
        }
    }
}
/// Internal state of a [`OnceFut`]: still awaiting the shared future, or the
/// cached (shared) result.
enum OnceFutState<T> {
    Pending(OnceFutPending<T>),
    Ready(SharedResult<Arc<T>>),
}
impl<T> Clone for OnceFutState<T> {
    // Manual impl to avoid a `T: Clone` bound; both variants hold shareable
    // (`Shared`/`Arc`) contents.
    fn clone(&self) -> Self {
        match self {
            Self::Pending(p) => Self::Pending(p.clone()),
            Self::Ready(r) => Self::Ready(r.clone()),
        }
    }
}
impl<T: 'static> OnceFut<T> {
    /// Wraps `fut` in a boxed, shared future whose output is `Arc`-wrapped so
    /// every clone of this `OnceFut` observes the same result.
    pub(crate) fn new<Fut>(fut: Fut) -> Self
    where
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        Self {
            state: OnceFutState::Pending(
                fut.map(|res| res.map(Arc::new).map_err(Arc::new))
                    .boxed()
                    .shared(),
            ),
        }
    }
    /// Polls the shared future, caching its result on completion, and returns
    /// a reference to the value.
    pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
        if let OnceFutState::Pending(fut) = &mut self.state {
            // `ready!` returns `Poll::Pending` early if not yet complete.
            let r = ready!(fut.poll_unpin(cx));
            self.state = OnceFutState::Ready(r);
        }
        // State transitioned above, so `Pending` is impossible here.
        match &self.state {
            OnceFutState::Pending(_) => unreachable!(),
            OnceFutState::Ready(r) => Poll::Ready(
                r.as_ref()
                    .map(|r| r.as_ref())
                    .map_err(DataFusionError::from),
            ),
        }
    }
    /// Like [`Self::get`], but returns an owned `Arc` to the shared value.
    pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
        if let OnceFutState::Pending(fut) = &mut self.state {
            let r = ready!(fut.poll_unpin(cx));
            self.state = OnceFutState::Ready(r);
        }
        match &self.state {
            OnceFutState::Pending(_) => unreachable!(),
            OnceFutState::Ready(r) => {
                Poll::Ready(r.clone().map_err(DataFusionError::Shared))
            }
        }
    }
}
/// Returns true for join types that must emit unmatched right-side rows in a
/// final pass (right-preserving outer/semi/anti/mark joins and full joins).
pub(crate) fn need_produce_right_in_final(join_type: JoinType) -> bool {
    match join_type {
        JoinType::Full
        | JoinType::Right
        | JoinType::RightAnti
        | JoinType::RightMark
        | JoinType::RightSemi => true,
        _ => false,
    }
}
/// Returns true for join types that must emit unmatched left-side (build
/// side) rows in a final pass after all probe batches are processed.
pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
    match join_type {
        JoinType::Left
        | JoinType::LeftAnti
        | JoinType::LeftSemi
        | JoinType::LeftMark
        | JoinType::Full => true,
        _ => false,
    }
}
/// Locks the shared visited-rows bitmap and delegates to
/// [`get_final_indices_from_bit_map`].
pub(crate) fn get_final_indices_from_shared_bitmap(
    shared_bitmap: &SharedBitmapBuilder,
    join_type: JoinType,
    piecewise: bool,
) -> (UInt64Array, UInt32Array) {
    // The lock guard is held only for the duration of the index computation.
    let bitmap = shared_bitmap.lock();
    get_final_indices_from_bit_map(&bitmap, join_type, piecewise)
}
/// Computes the final (build-side, probe-side) index pair from the
/// visited-rows bitmap once all probe batches have been processed.
///
/// NOTE(review): when `piecewise` is set, the Right* join types appear to be
/// treated like their Left* counterparts (the bitmap then tracks the right
/// side) — confirm against the piecewise merge join callers.
pub(crate) fn get_final_indices_from_bit_map(
    left_bit_map: &BooleanBufferBuilder,
    join_type: JoinType,
    piecewise: bool,
) -> (UInt64Array, UInt32Array) {
    let left_size = left_bit_map.len();
    // Mark joins: emit every build row; the probe index is a dummy 0 for
    // matched rows and null for unmatched (later rendered as the mark bool).
    if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise)
    {
        let left_indices = (0..left_size as u64).collect::<UInt64Array>();
        let right_indices = (0..left_size)
            .map(|idx| left_bit_map.get_bit(idx).then_some(0))
            .collect::<UInt32Array>();
        return (left_indices, right_indices);
    }
    // Semi joins keep the matched rows; every other type reaching here
    // (left/full outer, anti) keeps the unmatched rows.
    let left_indices = if join_type == JoinType::LeftSemi
        || (join_type == JoinType::RightSemi && piecewise)
    {
        (0..left_size)
            .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
            .collect::<UInt64Array>()
    } else {
        (0..left_size)
            .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
            .collect::<UInt64Array>()
    };
    // The probe side has no matching rows for these entries: all nulls.
    let mut builder = UInt32Builder::with_capacity(left_indices.len());
    builder.append_nulls(left_indices.len());
    let right_indices = builder.finish();
    (left_indices, right_indices)
}
/// Evaluates the join filter over the candidate (build, probe) index pairs
/// and returns only the pairs for which the filter evaluates to true.
///
/// When `max_intermediate_size` is set, the intermediate filter batch is
/// built and evaluated in chunks of at most that many rows to bound memory,
/// and the per-chunk boolean results are concatenated.
#[expect(clippy::too_many_arguments)]
pub(crate) fn apply_join_filter_to_indices(
    build_input_buffer: &RecordBatch,
    probe_batch: &RecordBatch,
    build_indices: UInt64Array,
    probe_indices: UInt32Array,
    filter: &JoinFilter,
    build_side: JoinSide,
    max_intermediate_size: Option<usize>,
    join_type: JoinType,
) -> Result<(UInt64Array, UInt32Array)> {
    // Nothing to filter.
    if build_indices.is_empty() && probe_indices.is_empty() {
        return Ok((build_indices, probe_indices));
    };
    let filter_result = if let Some(max_size) = max_intermediate_size {
        let mut filter_results =
            Vec::with_capacity(build_indices.len().div_ceil(max_size));
        // Process [i, i + max_size) windows of the index arrays.
        for i in (0..build_indices.len()).step_by(max_size) {
            let end = min(build_indices.len(), i + max_size);
            let len = end - i;
            let intermediate_batch = build_batch_from_indices(
                filter.schema(),
                build_input_buffer,
                probe_batch,
                &build_indices.slice(i, len),
                &probe_indices.slice(i, len),
                filter.column_indices(),
                build_side,
                join_type,
            )?;
            let filter_result = filter
                .expression()
                .evaluate(&intermediate_batch)?
                .into_array(intermediate_batch.num_rows())?;
            filter_results.push(filter_result);
        }
        let filter_refs: Vec<&dyn Array> =
            filter_results.iter().map(|a| a.as_ref()).collect();
        compute::concat(&filter_refs)?
    } else {
        // Unbounded: evaluate the filter over all candidate pairs at once.
        let intermediate_batch = build_batch_from_indices(
            filter.schema(),
            build_input_buffer,
            probe_batch,
            &build_indices,
            &probe_indices,
            filter.column_indices(),
            build_side,
            join_type,
        )?;
        filter
            .expression()
            .evaluate(&intermediate_batch)?
            .into_array(intermediate_batch.num_rows())?
    };
    // Keep only the index pairs where the filter evaluated to true
    // (nulls are treated as false by the filter kernel).
    let mask = as_boolean_array(&filter_result)?;
    let left_filtered = compute::filter(&build_indices, mask)?;
    let right_filtered = compute::filter(&probe_indices, mask)?;
    Ok((
        downcast_array(left_filtered.as_ref()),
        downcast_array(right_filtered.as_ref()),
    ))
}
/// Builds a `RecordBatch` with zero columns but an explicit `row_count`, for
/// joins whose output schema is empty (e.g. `COUNT(*)` over a join).
fn new_empty_schema_batch(schema: &Schema, row_count: usize) -> Result<RecordBatch> {
    let batch = RecordBatch::try_new_with_options(
        Arc::new(schema.clone()),
        Vec::new(),
        // Without columns, the row count must be supplied explicitly.
        &RecordBatchOptions::new().with_row_count(Some(row_count)),
    )?;
    Ok(batch)
}
/// Materializes an output `RecordBatch` by `take`-ing rows from the build and
/// probe batches according to `build_indices` / `probe_indices`.
///
/// A null index produces a null row on that side (outer join padding); a
/// `JoinSide::None` column index produces the boolean mark column derived
/// from probe-index validity.
#[expect(clippy::too_many_arguments)]
pub(crate) fn build_batch_from_indices(
    schema: &Schema,
    build_input_buffer: &RecordBatch,
    probe_batch: &RecordBatch,
    build_indices: &UInt64Array,
    probe_indices: &UInt32Array,
    column_indices: &[ColumnIndex],
    build_side: JoinSide,
    join_type: JoinType,
) -> Result<RecordBatch> {
    if schema.fields().is_empty() {
        // No output columns: only the row count matters. Right-semi/anti
        // output rows correspond to probe indices; everything else to build.
        let row_count = match join_type {
            JoinType::RightAnti | JoinType::RightSemi => probe_indices.len(),
            _ => build_indices.len(),
        };
        return new_empty_schema_batch(schema, row_count);
    }
    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
    for column_index in column_indices {
        let array = if column_index.side == JoinSide::None {
            // Mark column: true where the probe index is non-null (matched).
            Arc::new(compute::is_not_null(probe_indices)?)
        } else if column_index.side == build_side {
            let array = build_input_buffer.column(column_index.index);
            // Shortcut: if the source is empty or every index is null, skip
            // `take` and emit an all-null column directly.
            if array.is_empty() || build_indices.null_count() == build_indices.len() {
                assert_eq!(build_indices.null_count(), build_indices.len());
                new_null_array(array.data_type(), build_indices.len())
            } else {
                take(array.as_ref(), build_indices, None)?
            }
        } else {
            let array = probe_batch.column(column_index.index);
            if array.is_empty() || probe_indices.null_count() == probe_indices.len() {
                assert_eq!(probe_indices.null_count(), probe_indices.len());
                new_null_array(array.data_type(), probe_indices.len())
            } else {
                take(array.as_ref(), probe_indices, None)?
            }
        };
        columns.push(array);
    }
    Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
}
/// Produces the output batch for a probe batch when the build side is empty.
///
/// Join types that require a build-side match emit an empty batch; the
/// right-preserving types emit every probe row, padded with nulls on the left
/// (or an all-false mark column).
pub(crate) fn build_batch_empty_build_side(
    schema: &Schema,
    build_batch: &RecordBatch,
    probe_batch: &RecordBatch,
    column_indices: &[ColumnIndex],
    join_type: JoinType,
) -> Result<RecordBatch> {
    match join_type {
        // No build rows => no output rows for these types.
        JoinType::Inner
        | JoinType::Left
        | JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))),
        JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => {
            let num_rows = probe_batch.num_rows();
            if schema.fields().is_empty() {
                return new_empty_schema_batch(schema, num_rows);
            }
            let mut columns: Vec<Arc<dyn Array>> =
                Vec::with_capacity(schema.fields().len());
            for column_index in column_indices {
                let array = match column_index.side {
                    // Build side contributes only nulls (the batch is used
                    // solely for its column data types).
                    JoinSide::Left => new_null_array(
                        build_batch.column(column_index.index).data_type(),
                        num_rows,
                    ),
                    // Probe columns pass through unchanged.
                    JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)),
                    // Mark column: nothing matched, so all false.
                    JoinSide::None => Arc::new(BooleanArray::new(
                        BooleanBuffer::new_unset(num_rows),
                        None,
                    )),
                };
                columns.push(array);
            }
            Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
        }
    }
}
/// Post-processes the matched (build, probe) index pairs for one probe batch
/// according to the join type: appending unmatched probe rows, reducing to
/// semi/anti/mark indices, or clearing indices for build-side-final types.
///
/// `adjust_range` is the probe-row range covered by this batch.
pub(crate) fn adjust_indices_by_join_type(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    adjust_range: Range<usize>,
    join_type: JoinType,
    preserve_order_for_right: bool,
) -> Result<(UInt64Array, UInt32Array)> {
    match join_type {
        // Matched pairs are already exactly the output; unmatched left rows
        // (for Left) are handled in the final pass.
        JoinType::Inner | JoinType::Left => Ok((left_indices, right_indices)),
        JoinType::Right => append_right_indices(
            left_indices,
            right_indices,
            adjust_range,
            preserve_order_for_right,
        ),
        // Full join never needs probe-order preservation here.
        JoinType::Full => {
            append_right_indices(left_indices, right_indices, adjust_range, false)
        }
        JoinType::RightSemi => {
            // Deduplicated probe rows that found at least one match.
            let matched_right = get_semi_indices(adjust_range, &right_indices);
            Ok((left_indices, matched_right))
        }
        JoinType::RightAnti => {
            // Probe rows that found no match at all.
            let unmatched_right = get_anti_indices(adjust_range, &right_indices);
            Ok((left_indices, unmatched_right))
        }
        JoinType::RightMark => {
            // One output row per probe row; validity of the probe index
            // encodes the mark.
            let marks = get_mark_indices(&adjust_range, &right_indices);
            let build_side: UInt64Array = adjust_range.map(|i| i as u64).collect();
            Ok((build_side, marks))
        }
        // These types emit their results only in the final pass, driven by
        // the shared visited bitmap; per-batch output is empty.
        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => Ok((
            UInt64Array::from_iter_values(Vec::new()),
            UInt32Array::from_iter_values(Vec::new()),
        )),
    }
}
/// Appends the unmatched probe rows (with null build indices) to the matched
/// (build, probe) pairs, for right/full outer joins.
///
/// With `preserve_order_for_right`, the unmatched rows are interleaved so the
/// probe indices come out sorted; otherwise they are appended at the end.
pub(crate) fn append_right_indices(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    adjust_range: Range<usize>,
    preserve_order_for_right: bool,
) -> Result<(UInt64Array, UInt32Array)> {
    if preserve_order_for_right {
        Ok(append_probe_indices_in_order(
            &left_indices,
            &right_indices,
            adjust_range,
        ))
    } else {
        let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices);
        if right_unmatched_indices.is_empty() {
            // Every probe row matched; nothing to append.
            Ok((left_indices, right_indices))
        } else {
            // `into_builder` reuses the array's buffer when this is the sole
            // reference; otherwise fall back to copying into a new builder.
            let mut new_left_indices_builder =
                left_indices.into_builder().unwrap_or_else(|left_indices| {
                    let mut builder = UInt64Builder::with_capacity(
                        left_indices.len() + right_unmatched_indices.len(),
                    );
                    // `append_slice` copies only values, so nulls must be
                    // absent for the copy to be faithful.
                    debug_assert_eq!(
                        left_indices.null_count(),
                        0,
                        "expected left indices to have no nulls"
                    );
                    builder.append_slice(left_indices.values());
                    builder
                });
            // Unmatched probe rows get null build indices.
            new_left_indices_builder.append_nulls(right_unmatched_indices.len());
            let new_left_indices = UInt64Array::from(new_left_indices_builder.finish());
            let mut new_right_indices_builder = right_indices
                .into_builder()
                .unwrap_or_else(|right_indices| {
                    let mut builder = UInt32Builder::with_capacity(
                        right_indices.len() + right_unmatched_indices.len(),
                    );
                    debug_assert_eq!(
                        right_indices.null_count(),
                        0,
                        "expected right indices to have no nulls"
                    );
                    builder.append_slice(right_indices.values());
                    builder
                });
            debug_assert_eq!(
                right_unmatched_indices.null_count(),
                0,
                "expected right unmatched indices to have no nulls"
            );
            new_right_indices_builder.append_slice(right_unmatched_indices.values());
            let new_right_indices = UInt32Array::from(new_right_indices_builder.finish());
            Ok((new_left_indices, new_right_indices))
        }
    }
}
/// Returns the indices in `range` that do NOT appear in `input_indices`
/// (i.e. probe rows with no match), for anti joins.
pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
    range: Range<usize>,
    input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    // One bit per position in `range`, set when that index was matched.
    let matched = build_range_bitmap(&range, input_indices);
    let start = range.start;
    range
        .filter_map(|idx| match matched.get_bit(idx - start) {
            // Unmatched: keep this index (converted to the native type).
            false => Some(T::Native::from_usize(idx)),
            true => None,
        })
        .collect()
}
/// Returns the indices in `range` that DO appear in `input_indices`
/// (i.e. probe rows with at least one match), deduplicated, for semi joins.
pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
    range: Range<usize>,
    input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    // One bit per position in `range`, set when that index was matched.
    let matched = build_range_bitmap(&range, input_indices);
    let start = range.start;
    range
        .filter_map(|idx| match matched.get_bit(idx - start) {
            // Matched: keep this index (converted to the native type).
            true => Some(T::Native::from_usize(idx)),
            false => None,
        })
        .collect()
}
/// Builds the mark-join index array for `range`: one entry per probe row,
/// whose validity bit is set iff the row appears in `input_indices`.
/// The values themselves are dummy zeros; only the null mask matters.
pub(crate) fn get_mark_indices<T: ArrowPrimitiveType>(
    range: &Range<usize>,
    input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<UInt32Type>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    let mut matched = build_range_bitmap(range, input_indices);
    let validity = NullBuffer::new(matched.finish());
    let values = vec![0; range.len()];
    PrimitiveArray::new(values.into(), Some(validity))
}
/// Builds a bitmap over `range` with a bit set for every non-null value of
/// `input` that falls inside the range (positions relative to `range.start`).
fn build_range_bitmap<T: ArrowPrimitiveType>(
    range: &Range<usize>,
    input: &PrimitiveArray<T>,
) -> BooleanBufferBuilder {
    let mut bitmap = BooleanBufferBuilder::new(range.len());
    // Start with everything unset; `flatten` skips null entries.
    bitmap.append_n(range.len(), false);
    for value in input.iter().flatten() {
        let idx = value.as_usize();
        if idx >= range.start && idx < range.end {
            bitmap.set_bit(idx - range.start, true);
        }
    }
    bitmap
}
/// Interleaves unmatched probe rows (with null build indices) into the
/// matched pairs so that the resulting probe indices cover `range` in sorted
/// order.
///
/// Relies on `probe_indices` already being non-decreasing within `range`;
/// gaps between consecutive probe indices are filled with (null, gap_value)
/// pairs.
fn append_probe_indices_in_order(
    build_indices: &PrimitiveArray<UInt64Type>,
    probe_indices: &PrimitiveArray<UInt32Type>,
    range: Range<usize>,
) -> (PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>) {
    let mut new_build_indices = UInt64Builder::new();
    let mut new_probe_indices = UInt32Builder::new();
    // Next probe index that has not yet been emitted.
    let mut prev_index = range.start as u32;
    debug_assert!(build_indices.len() == probe_indices.len());
    for (build_index, probe_index) in build_indices
        .values()
        .into_iter()
        .zip(probe_indices.values().into_iter())
    {
        // Emit every unmatched probe row between the previous and current
        // matched probe index.
        for value in prev_index..*probe_index {
            new_probe_indices.append_value(value);
            new_build_indices.append_null();
        }
        // Then the matched pair itself.
        new_probe_indices.append_value(*probe_index);
        new_build_indices.append_value(*build_index);
        prev_index = probe_index + 1;
    }
    // Trailing unmatched probe rows after the last match.
    for value in prev_index..range.end as u32 {
        new_probe_indices.append_value(value);
        new_build_indices.append_null();
    }
    (new_build_indices.finish(), new_probe_indices.finish())
}
/// Metrics for build/probe style joins (e.g. hash join): timing and volume
/// counters for both phases, plus summary ratios.
#[derive(Clone, Debug)]
pub(crate) struct BuildProbeJoinMetrics {
    pub(crate) baseline: BaselineMetrics,
    // Time spent constructing the build side.
    pub(crate) build_time: metrics::Time,
    pub(crate) build_input_batches: metrics::Count,
    pub(crate) build_input_rows: metrics::Count,
    // Memory consumed by the build-side structures.
    pub(crate) build_mem_used: metrics::Gauge,
    // Time spent probing and producing output.
    pub(crate) join_time: metrics::Time,
    pub(crate) input_batches: metrics::Count,
    pub(crate) input_rows: metrics::Count,
    pub(crate) probe_hit_rate: metrics::RatioMetrics,
    pub(crate) avg_fanout: metrics::RatioMetrics,
}
impl Drop for BuildProbeJoinMetrics {
    // Fold build and probe time into elapsed_compute when the operator's
    // metrics are dropped, so the baseline reflects total work.
    fn drop(&mut self) {
        self.baseline.elapsed_compute().add(&self.build_time);
        self.baseline.elapsed_compute().add(&self.join_time);
    }
}
impl BuildProbeJoinMetrics {
    /// Registers all build/probe join metrics for `partition` on `metrics`.
    ///
    /// Registration order is kept identical to the original so any
    /// order-sensitive consumers of the metric set observe the same sequence.
    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
        let baseline_metrics = BaselineMetrics::new(metrics, partition);
        let probe_phase_time =
            MetricBuilder::new(metrics).subset_time("join_time", partition);
        let build_phase_time =
            MetricBuilder::new(metrics).subset_time("build_time", partition);
        let build_batches_seen =
            MetricBuilder::new(metrics).counter("build_input_batches", partition);
        let build_rows_seen =
            MetricBuilder::new(metrics).counter("build_input_rows", partition);
        let build_memory =
            MetricBuilder::new(metrics).gauge("build_mem_used", partition);
        let probe_batches_seen =
            MetricBuilder::new(metrics).counter("input_batches", partition);
        let probe_rows_seen = MetricBuilder::new(metrics).counter("input_rows", partition);
        // Summary-typed ratios reported once per partition.
        let hit_rate = MetricBuilder::new(metrics)
            .with_type(MetricType::SUMMARY)
            .ratio_metrics("probe_hit_rate", partition);
        let fanout = MetricBuilder::new(metrics)
            .with_type(MetricType::SUMMARY)
            .ratio_metrics("avg_fanout", partition);
        Self {
            baseline: baseline_metrics,
            build_time: build_phase_time,
            build_input_batches: build_batches_seen,
            build_input_rows: build_rows_seen,
            build_mem_used: build_memory,
            join_time: probe_phase_time,
            input_batches: probe_batches_seen,
            input_rows: probe_rows_seen,
            probe_hit_rate: hit_rate,
            avg_fanout: fanout,
        }
    }
}
/// Converts a `Result<StatefulStreamResult<_>>` into stream-poll control
/// flow: `Continue` restarts the enclosing loop, `Ready` yields a
/// `Poll::Ready` item, and errors surface as `Poll::Ready(Some(Err(..)))`.
///
/// Must be used inside a loop within a `poll_next`-style method.
#[macro_export]
macro_rules! handle_state {
    ($match_case:expr) => {
        match $match_case {
            Ok(StatefulStreamResult::Continue) => continue,
            Ok(StatefulStreamResult::Ready(result)) => {
                Poll::Ready(Ok(result).transpose())
            }
            Err(e) => Poll::Ready(Some(Err(e))),
        }
    };
}
/// Outcome of one step of a stateful stream: either a value is ready to be
/// emitted, or the state machine advanced and should be polled again.
pub enum StatefulStreamResult<T> {
    Ready(T),
    Continue,
}
/// Output partitioning for joins that preserve the partitioning of one side
/// (symmetric-style joins): the preserved side's partitioning is kept, with
/// right-side column references shifted when both sides appear in the output.
pub(crate) fn symmetric_join_output_partitioning(
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    join_type: &JoinType,
) -> Result<Partitioning> {
    match join_type {
        // Left-preserving types keep the left partitioning verbatim.
        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
            Ok(left.output_partitioning().clone())
        }
        // Right-only outputs keep the right partitioning verbatim.
        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
            Ok(right.output_partitioning().clone())
        }
        // Both sides in output: shift right-side hash expressions.
        JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
            right.output_partitioning(),
            left.schema().fields.len(),
        ),
        // Full joins pad both sides with nulls, invalidating any hash scheme.
        JoinType::Full => Ok(Partitioning::UnknownPartitioning(
            right.output_partitioning().partition_count(),
        )),
    }
}
/// Output partitioning for joins where only the right (probe) side's
/// partitioning can survive; left-preserving and full joins degrade to
/// unknown partitioning.
pub(crate) fn asymmetric_join_output_partitioning(
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    join_type: &JoinType,
) -> Result<Partitioning> {
    match join_type {
        // Both sides in output: shift right-side hash expressions by the
        // width of the left schema.
        JoinType::Inner | JoinType::Right => {
            let left_columns_len = left.schema().fields().len();
            adjust_right_output_partitioning(right.output_partitioning(), left_columns_len)
        }
        // Right-only outputs keep the right partitioning verbatim.
        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
            Ok(right.output_partitioning().clone())
        }
        // Left-preserving and full joins cannot guarantee any scheme.
        JoinType::Left
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::Full
        | JoinType::LeftMark => Ok(Partitioning::UnknownPartitioning(
            right.output_partitioning().partition_count(),
        )),
    }
}
/// Transforms one input batch into zero or more output batches.
pub(crate) trait BatchTransformer: Debug + Clone {
    /// Installs the next input batch to be transformed.
    fn set_batch(&mut self, batch: RecordBatch);
    /// Returns the next output batch and whether it is the last one for the
    /// current input batch; `None` when the current batch is exhausted.
    fn next(&mut self) -> Option<(RecordBatch, bool)>;
}
/// A [`BatchTransformer`] that passes batches through unchanged.
#[derive(Debug, Clone)]
pub(crate) struct NoopBatchTransformer {
    // Batch most recently handed to `set_batch`; consumed by `next`.
    batch: Option<RecordBatch>,
}

impl NoopBatchTransformer {
    /// Creates a transformer with no pending batch.
    pub fn new() -> Self {
        Self { batch: None }
    }
}
impl BatchTransformer for NoopBatchTransformer {
    /// Stores the batch to be returned verbatim by the next `next` call.
    fn set_batch(&mut self, batch: RecordBatch) {
        self.batch = Some(batch);
    }

    /// Hands back the stored batch unchanged; always flagged as the last one.
    fn next(&mut self) -> Option<(RecordBatch, bool)> {
        let pending = self.batch.take()?;
        Some((pending, true))
    }
}
/// A [`BatchTransformer`] that splits a batch into slices of at most
/// `batch_size` rows.
#[derive(Debug, Clone)]
pub(crate) struct BatchSplitter {
    // Batch currently being split; `None` once fully consumed.
    batch: Option<RecordBatch>,
    // Maximum number of rows per emitted slice.
    batch_size: usize,
    // Index of the first row of `batch` not yet emitted.
    row_index: usize,
}

impl BatchSplitter {
    /// Creates a splitter producing slices of at most `batch_size` rows.
    pub(crate) fn new(batch_size: usize) -> Self {
        Self {
            batch: None,
            batch_size,
            row_index: 0,
        }
    }
}
impl BatchTransformer for BatchSplitter {
fn set_batch(&mut self, batch: RecordBatch) {
self.batch = Some(batch);
self.row_index = 0;
}
fn next(&mut self) -> Option<(RecordBatch, bool)> {
let Some(batch) = &self.batch else {
return None;
};
let remaining_rows = batch.num_rows() - self.row_index;
let rows_to_slice = remaining_rows.min(self.batch_size);
let sliced_batch = batch.slice(self.row_index, rows_to_slice);
self.row_index += rows_to_slice;
let mut last = false;
if self.row_index >= batch.num_rows() {
self.batch = None;
last = true;
}
Some((sliced_batch, last))
}
}
/// Wraps a swapped join in a projection that restores the original
/// `[left | right]` column order expected by downstream consumers.
pub fn reorder_output_after_swap(
    plan: Arc<dyn ExecutionPlan>,
    left_schema: &Schema,
    right_schema: &Schema,
) -> Result<Arc<dyn ExecutionPlan>> {
    let reverting_exprs = swap_reverting_projection(left_schema, right_schema);
    let projection = ProjectionExec::try_new(reverting_exprs, plan)?;
    Ok(Arc::new(projection))
}
/// Builds the projection that undoes a join-side swap.
///
/// After a swap the plan's output is `[right | left]`; this projection emits
/// the left columns first (reading them from positions after the right
/// block), then the right columns (read from the front), restoring the
/// original `[left | right]` order.
fn swap_reverting_projection(
    left_schema: &Schema,
    right_schema: &Schema,
) -> Vec<ProjectionExpr> {
    let right_len = right_schema.fields().len();
    let mut exprs = Vec::with_capacity(right_len + left_schema.fields().len());

    // Left columns live after the right block in the swapped output.
    for (i, field) in left_schema.fields().iter().enumerate() {
        exprs.push(ProjectionExpr {
            expr: Arc::new(Column::new(field.name(), right_len + i))
                as Arc<dyn PhysicalExpr>,
            alias: field.name().to_owned(),
        });
    }
    // Right columns occupy the front of the swapped output.
    for (i, field) in right_schema.fields().iter().enumerate() {
        exprs.push(ProjectionExpr {
            expr: Arc::new(Column::new(field.name(), i)) as Arc<dyn PhysicalExpr>,
            alias: field.name().to_owned(),
        });
    }
    exprs
}
/// Rewrites an output projection so it is valid after the join's inputs have
/// been swapped.
///
/// Semi/anti/mark joins emit columns of a single side, so their projection
/// indices are unaffected by a swap. For all other join types the output is
/// the concatenation of both sides, and each index must be moved across the
/// boundary: former-left indices shift past the (new) right block, and
/// former-right indices move to the front.
pub fn swap_join_projection(
    left_schema_len: usize,
    right_schema_len: usize,
    projection: Option<&[usize]>,
    join_type: &JoinType,
) -> Option<Vec<usize>> {
    let single_sided_output = matches!(
        join_type,
        JoinType::LeftAnti
            | JoinType::LeftSemi
            | JoinType::RightAnti
            | JoinType::RightSemi
            | JoinType::LeftMark
            | JoinType::RightMark
    );
    projection.map(|indices| {
        if single_sided_output {
            indices.to_vec()
        } else {
            indices
                .iter()
                .map(|&i| {
                    if i < left_schema_len {
                        i + right_schema_len
                    } else {
                        i - left_schema_len
                    }
                })
                .collect()
        }
    })
}
/// Hashes the join-key columns of `batch` and inserts the resulting
/// `(row index, hash)` pairs into `hash_map`.
///
/// Row indices are globalized by adding `offset` (the number of rows already
/// inserted from earlier batches). `hashes_buffer` is reused scratch space
/// for the computed hashes; `deleted_offset` is forwarded to the map for
/// pruned-row bookkeeping. When `fifo_hashmap` is set, entries are inserted
/// in reverse iteration order.
#[expect(clippy::too_many_arguments)]
pub fn update_hash(
    on: &[PhysicalExprRef],
    batch: &RecordBatch,
    hash_map: &mut dyn JoinHashMapType,
    offset: usize,
    random_state: &RandomState,
    hashes_buffer: &mut [u64],
    deleted_offset: usize,
    fifo_hashmap: bool,
) -> Result<()> {
    // Evaluate the join-key expressions and hash every row of the batch.
    let key_arrays = evaluate_expressions_to_arrays(on, batch)?;
    let row_hashes = create_hashes(&key_arrays, random_state, hashes_buffer)?;

    // Reserve slots in the map's chained index list for the incoming rows.
    hash_map.extend_zero(batch.num_rows());

    // Pair each hash with its global row index.
    let indexed = row_hashes
        .iter()
        .enumerate()
        .map(|(row, hash)| (row + offset, hash));

    if fifo_hashmap {
        hash_map.update_from_iter(Box::new(indexed.rev()), deleted_offset);
    } else {
        hash_map.update_from_iter(Box::new(indexed), deleted_offset);
    }
    Ok(())
}
/// Filters a pair of candidate-match index arrays down to the pairs whose
/// join-key values are genuinely equal.
///
/// `indices_left` / `indices_right` are row positions of candidate matches
/// (e.g. hash-collision candidates); `left_arrays` / `right_arrays` are the
/// corresponding join-key columns. Keys are compared column by column and
/// only the index pairs where every column compares equal are kept.
/// `null_equality` controls whether two NULL keys are considered equal.
pub(super) fn equal_rows_arr(
    indices_left: &UInt64Array,
    indices_right: &UInt32Array,
    left_arrays: &[ArrayRef],
    right_arrays: &[ArrayRef],
    null_equality: NullEquality,
) -> Result<(UInt64Array, UInt32Array)> {
    let mut iter = left_arrays.iter().zip(right_arrays.iter());

    // No join-key columns: return empty index arrays.
    let Some((first_left, first_right)) = iter.next() else {
        return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
    };

    // Gather the key values at the candidate positions for the first column,
    // seeding the per-row equality bitmap...
    let arr_left = take(first_left.as_ref(), indices_left, None)?;
    let arr_right = take(first_right.as_ref(), indices_right, None)?;
    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;

    // ...then AND in the equality result of each remaining key column,
    // short-circuiting on the first error.
    equal = iter
        .map(|(left, right)| {
            let arr_left = take(left.as_ref(), indices_left, None)?;
            let arr_right = take(right.as_ref(), indices_right, None)?;
            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
        })
        .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;

    // Keep only the index pairs whose keys compared equal on all columns.
    let filter_builder = FilterBuilder::new(&equal).optimize().build();
    let left_filtered = filter_builder.filter(indices_left)?;
    let right_filtered = filter_builder.filter(indices_right)?;

    Ok((
        downcast_array(left_filtered.as_ref()),
        downcast_array(right_filtered.as_ref()),
    ))
}
/// Element-wise equality of two arrays, with configurable NULL semantics:
/// under `NullEqualsNull` two NULLs compare equal (`IS NOT DISTINCT FROM`),
/// otherwise standard SQL equality is used.
fn eq_dyn_null(
    left: &dyn Array,
    right: &dyn Array,
    null_equality: NullEquality,
) -> Result<BooleanArray, ArrowError> {
    let nulls_are_equal = matches!(null_equality, NullEquality::NullEqualsNull);

    // Nested types (struct/list/...) are not handled by the flat `eq` /
    // `not_distinct` kernels and go through the nested comparator instead.
    if left.data_type().is_nested() {
        let op = if nulls_are_equal {
            Operator::IsNotDistinctFrom
        } else {
            Operator::Eq
        };
        return Ok(compare_op_for_nested(op, &left, &right)?);
    }

    if nulls_are_equal {
        not_distinct(&left, &right)
    } else {
        eq(&left, &right)
    }
}
/// Lexicographically compares row `left` of `left_arrays` against row `right`
/// of `right_arrays`, honoring the per-column `sort_options`.
///
/// Returns the ordering of the first column pair that compares non-equal
/// (subsequent columns are not inspected), or `Ordering::Equal` when every
/// column ties. Used by sort-merge join to advance its stream cursors.
///
/// Null handling: a NULL on exactly one side orders according to
/// `nulls_first`; NULLs on both sides compare `Equal` under `NullEqualsNull`
/// and `Less` under `NullEqualsNothing`.
///
/// # Errors
/// Returns `NotImplemented` for column types the comparator does not support.
pub fn compare_join_arrays(
    left_arrays: &[ArrayRef],
    left: usize,
    right_arrays: &[ArrayRef],
    right: usize,
    sort_options: &[SortOptions],
    null_equality: NullEquality,
) -> Result<Ordering> {
    let mut res = Ordering::Equal;
    for ((left_array, right_array), sort_options) in
        left_arrays.iter().zip(right_arrays).zip(sort_options)
    {
        // Downcasts both arrays to the concrete type `$T` and compares the
        // values at `left` / `right`, storing the outcome in `res`.
        macro_rules! compare_value {
            ($T:ty) => {{
                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
                match (left_array.is_null(left), right_array.is_null(right)) {
                    (false, false) => {
                        let left_value = &left_array.value(left);
                        let right_value = &right_array.value(right);
                        res = left_value.partial_cmp(right_value).unwrap();
                        // Descending columns invert the natural order.
                        if sort_options.descending {
                            res = res.reverse();
                        }
                    }
                    // NULL on the left only: position per `nulls_first`.
                    (true, false) => {
                        res = if sort_options.nulls_first {
                            Ordering::Less
                        } else {
                            Ordering::Greater
                        };
                    }
                    // NULL on the right only: mirror image of the above.
                    (false, true) => {
                        res = if sort_options.nulls_first {
                            Ordering::Greater
                        } else {
                            Ordering::Less
                        };
                    }
                    // NULL on both sides: equality depends on `null_equality`.
                    _ => {
                        res = match null_equality {
                            NullEquality::NullEqualsNothing => Ordering::Less,
                            NullEquality::NullEqualsNull => Ordering::Equal,
                        };
                    }
                }
            }};
        }
        match left_array.data_type() {
            DataType::Null => {}
            DataType::Boolean => compare_value!(BooleanArray),
            DataType::Int8 => compare_value!(Int8Array),
            DataType::Int16 => compare_value!(Int16Array),
            DataType::Int32 => compare_value!(Int32Array),
            DataType::Int64 => compare_value!(Int64Array),
            DataType::UInt8 => compare_value!(UInt8Array),
            DataType::UInt16 => compare_value!(UInt16Array),
            DataType::UInt32 => compare_value!(UInt32Array),
            DataType::UInt64 => compare_value!(UInt64Array),
            DataType::Float32 => compare_value!(Float32Array),
            DataType::Float64 => compare_value!(Float64Array),
            DataType::Binary => compare_value!(BinaryArray),
            DataType::BinaryView => compare_value!(BinaryViewArray),
            DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
            DataType::LargeBinary => compare_value!(LargeBinaryArray),
            DataType::Utf8 => compare_value!(StringArray),
            DataType::Utf8View => compare_value!(StringViewArray),
            DataType::LargeUtf8 => compare_value!(LargeStringArray),
            DataType::Decimal128(..) => compare_value!(Decimal128Array),
            // NOTE(review): only non-timezoned timestamps are handled here.
            DataType::Timestamp(time_unit, None) => match time_unit {
                TimeUnit::Second => compare_value!(TimestampSecondArray),
                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
            },
            DataType::Date32 => compare_value!(Date32Array),
            DataType::Date64 => compare_value!(Date64Array),
            dt => {
                return not_impl_err!(
                    "Unsupported data type in sort merge join comparator: {}",
                    dt
                );
            }
        }
        // First non-equal column decides the ordering.
        if !res.is_eq() {
            break;
        }
    }
    Ok(res)
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::pin::Pin;
use super::*;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Fields};
use arrow::error::{ArrowError, Result as ArrowResult};
use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
use datafusion_common::{ScalarValue, arrow_datafusion_err, arrow_err};
use datafusion_physical_expr::PhysicalSortExpr;
use rstest::rstest;
/// Test helper: materializes both column slices into `HashSet`s and delegates
/// to `check_join_set_is_valid`.
fn check(
    left: &[Column],
    right: &[Column],
    on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> Result<()> {
    let left_set: HashSet<Column> = left.iter().cloned().collect();
    let right_set: HashSet<Column> = right.iter().cloned().collect();
    check_join_set_is_valid(&left_set, &right_set, on)
}
#[test]
fn check_valid() -> Result<()> {
    // Both sides expose column "a", so joining on it must validate.
    let left = [Column::new("a", 0), Column::new("b1", 1)];
    let right = [Column::new("a", 0), Column::new("b2", 1)];
    let on = &[(
        Arc::new(Column::new("a", 0)) as _,
        Arc::new(Column::new("a", 0)) as _,
    )];
    check(&left, &right, on)
}
#[test]
fn check_not_in_right() {
    // "a" exists only on the left, so a join keyed on it must be rejected.
    let left = [Column::new("a", 0), Column::new("b", 1)];
    let right = [Column::new("b", 0)];
    let on = &[(
        Arc::new(Column::new("a", 0)) as _,
        Arc::new(Column::new("a", 0)) as _,
    )];
    assert!(check(&left, &right, on).is_err());
}
/// Verifies that an Arrow error produced inside a shared `OnceFut` survives
/// the Arrow -> DataFusion error wrapping as the root cause.
#[tokio::test]
async fn check_error_nesting() {
    // A shared future that always resolves to an Arrow CSV error.
    let once_fut = OnceFut::<()>::new(async {
        arrow_err!(ArrowError::CsvError("some error".to_string()))
    });

    // Adapter that polls the OnceFut and surfaces its error as an ArrowError.
    struct TestFut(OnceFut<()>);
    impl Future for TestFut {
        type Output = ArrowResult<()>;
        fn poll(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Self::Output> {
            match ready!(self.0.get(cx)) {
                Ok(()) => Poll::Ready(Ok(())),
                Err(e) => Poll::Ready(Err(e.into())),
            }
        }
    }

    let res = TestFut(once_fut).await;
    let arrow_err_from_fut = res.expect_err("once_fut always return error");
    let wrapped_err = DataFusionError::from(arrow_err_from_fut);
    let root_err = wrapped_err.find_root();
    let expected =
        arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned()));
    // BUG FIX: the original `assert!(matches!(root_err, _expected))` bound
    // `_expected` as a fresh, irrefutable pattern variable, so the assertion
    // could never fail regardless of the error. Compare the rendered errors
    // instead to actually verify the root cause.
    assert_eq!(root_err.strip_backtrace(), expected.strip_backtrace());
}
#[test]
fn check_not_in_left() {
    // "a" exists only on the right, so a join keyed on it must be rejected.
    let left = [Column::new("b", 0)];
    let right = [Column::new("a", 0)];
    let on = &[(
        Arc::new(Column::new("a", 0)) as _,
        Arc::new(Column::new("a", 0)) as _,
    )];
    assert!(check(&left, &right, on).is_err());
}
#[test]
fn check_collision() {
    // "a" and "b" both exist on either side of the join keys ("a" on the
    // left, "b" on the right), so the join condition is valid even though
    // "a" also appears on the right.
    let left = [Column::new("a", 0), Column::new("c", 1)];
    let right = [Column::new("a", 0), Column::new("b", 1)];
    let on = &[(
        Arc::new(Column::new("a", 0)) as _,
        Arc::new(Column::new("b", 1)) as _,
    )];
    assert!(check(&left, &right, on).is_ok());
}
#[test]
fn check_in_right() {
    // Left key "a" is on the left, right key "b" is on the right: valid.
    let left = [Column::new("a", 0), Column::new("c", 1)];
    let right = [Column::new("b", 0)];
    let on = &[(
        Arc::new(Column::new("a", 0)) as _,
        Arc::new(Column::new("b", 0)) as _,
    )];
    assert!(check(&left, &right, on).is_ok());
}
/// Verifies nullability of join output columns: the side(s) that can emit
/// unmatched rows force the opposite side's columns to become nullable.
#[test]
fn test_join_schema() -> Result<()> {
    let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
    let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
    // (left_in, right_in, join_type, expected_left_out, expected_right_out)
    let cases = vec![
        // Inner join preserves nullability of both inputs.
        (&a, &b, JoinType::Inner, &a, &b),
        (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
        (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
        (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
        // Left join makes the right side nullable.
        (&a, &b, JoinType::Left, &a, &b_nulls),
        (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
        (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
        (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
        // Right join makes the left side nullable.
        (&a, &b, JoinType::Right, &a_nulls, &b),
        (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
        (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
        (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
        // Full join makes both sides nullable.
        (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
        (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
        (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
        (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
    ];
    for (left_in, right_in, join_type, left_out, right_out) in cases {
        let (schema, _) = build_join_schema(left_in, right_in, &join_type);
        let expected_fields = left_out
            .fields()
            .iter()
            .cloned()
            .chain(right_out.fields().iter().cloned())
            .collect::<Fields>();
        let expected_schema = Schema::new(expected_fields);
        assert_eq!(
            schema,
            expected_schema,
            "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
            left_in.fields()[0].name(),
            left_in.fields()[0].is_nullable(),
            right_in.fields()[0].name(),
            right_in.fields()[0].is_nullable(),
            join_type
        );
    }
    Ok(())
}
fn create_stats(
num_rows: Option<usize>,
column_stats: Vec<ColumnStatistics>,
is_exact: bool,
) -> Statistics {
Statistics {
num_rows: if is_exact {
num_rows.map(Exact)
} else {
num_rows.map(Inexact)
}
.unwrap_or(Absent),
column_statistics: column_stats,
total_byte_size: Absent,
}
}
/// Test helper: builds [`ColumnStatistics`] from i64 min/max bounds plus
/// distinct/null counts; sum and byte size are left absent.
fn create_column_stats(
    min: Precision<i64>,
    max: Precision<i64>,
    distinct_count: Precision<usize>,
    null_count: Precision<usize>,
) -> ColumnStatistics {
    ColumnStatistics {
        min_value: min.map(ScalarValue::from),
        max_value: max.map(ScalarValue::from),
        distinct_count,
        null_count,
        sum_value: Absent,
        byte_size: Absent,
    }
}
// Shorthand for one side's statistics in the cardinality tests:
// (num_rows, min, max, distinct_count, null_count).
type PartialStats = (
    usize,
    Precision<i64>,
    Precision<i64>,
    Precision<usize>,
    Precision<usize>,
);
/// Exercises inner-join cardinality estimation on a single join column for a
/// matrix of min/max/distinct/null-count combinations, and checks that
/// `estimate_join_cardinality` agrees with `estimate_inner_join_cardinality`.
#[test]
fn test_inner_join_cardinality_single_column() -> Result<()> {
    // (left stats, right stats, expected estimated cardinality)
    let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
        // Identical ranges, no distinct/null info.
        (
            (10, Inexact(1), Inexact(10), Absent, Absent),
            (10, Inexact(1), Inexact(10), Absent, Absent),
            Some(Inexact(10)),
        ),
        // Overlapping ranges; the estimate uses the larger max-distinct side.
        (
            (10, Inexact(6), Inexact(10), Absent, Absent),
            (10, Inexact(8), Inexact(10), Absent, Absent),
            Some(Inexact(20)),
        ),
        (
            (10, Inexact(8), Inexact(10), Absent, Absent),
            (10, Inexact(6), Inexact(10), Absent, Absent),
            Some(Inexact(20)),
        ),
        (
            (10, Inexact(1), Inexact(15), Absent, Absent),
            (20, Inexact(1), Inexact(40), Absent, Absent),
            Some(Inexact(10)),
        ),
        // Explicit distinct counts take precedence over range width.
        (
            (10, Inexact(1), Inexact(10), Inexact(10), Absent),
            (10, Inexact(1), Inexact(10), Inexact(10), Absent),
            Some(Inexact(10)),
        ),
        (
            (10, Inexact(1), Inexact(3), Inexact(10), Absent),
            (10, Inexact(1), Inexact(3), Inexact(10), Absent),
            Some(Inexact(10)),
        ),
        (
            (10, Inexact(1), Inexact(10), Inexact(5), Absent),
            (10, Inexact(1), Inexact(10), Inexact(2), Absent),
            Some(Inexact(20)),
        ),
        (
            (10, Inexact(1), Inexact(10), Inexact(2), Absent),
            (10, Inexact(1), Inexact(10), Inexact(5), Absent),
            Some(Inexact(20)),
        ),
        // Negative and mixed-sign ranges.
        (
            (10, Inexact(-5), Inexact(5), Absent, Absent),
            (10, Inexact(1), Inexact(5), Absent, Absent),
            Some(Inexact(10)),
        ),
        (
            (10, Inexact(-25), Inexact(-20), Absent, Absent),
            (10, Inexact(-25), Inexact(-15), Absent, Absent),
            Some(Inexact(10)),
        ),
        (
            (10, Inexact(-10), Inexact(0), Absent, Absent),
            (10, Inexact(0), Inexact(10), Inexact(5), Absent),
            Some(Inexact(10)),
        ),
        // Single-value ranges behave like a cross product.
        (
            (10, Inexact(1), Inexact(1), Absent, Absent),
            (10, Inexact(1), Inexact(1), Absent, Absent),
            Some(Inexact(100)),
        ),
        // No range info at all still yields a default estimate.
        (
            (10, Absent, Absent, Absent, Absent),
            (10, Absent, Absent, Absent, Absent),
            Some(Inexact(10)),
        ),
        (
            (10, Absent, Absent, Inexact(3), Absent),
            (10, Absent, Absent, Inexact(3), Absent),
            Some(Inexact(33)),
        ),
        // Partial (one-sided) range info combined with distinct counts.
        (
            (10, Inexact(2), Absent, Inexact(3), Absent),
            (10, Absent, Inexact(5), Inexact(3), Absent),
            Some(Inexact(33)),
        ),
        (
            (10, Absent, Inexact(3), Inexact(3), Absent),
            (10, Inexact(1), Absent, Inexact(3), Absent),
            Some(Inexact(33)),
        ),
        (
            (10, Absent, Inexact(3), Absent, Absent),
            (10, Inexact(1), Absent, Absent, Absent),
            Some(Inexact(10)),
        ),
        // Provably disjoint ranges estimate zero rows.
        (
            (10, Absent, Inexact(4), Absent, Absent),
            (10, Inexact(5), Absent, Absent, Absent),
            Some(Inexact(0)),
        ),
        (
            (10, Inexact(0), Inexact(10), Absent, Absent),
            (10, Inexact(11), Inexact(20), Absent, Absent),
            Some(Inexact(0)),
        ),
        (
            (10, Inexact(11), Inexact(20), Absent, Absent),
            (10, Inexact(0), Inexact(10), Absent, Absent),
            Some(Inexact(0)),
        ),
        // Zero distinct values: no estimate can be produced.
        (
            (10, Inexact(1), Inexact(10), Inexact(0), Absent),
            (10, Inexact(1), Inexact(10), Inexact(0), Absent),
            None,
        ),
        // Empty left input dominates the estimate.
        (
            (0, Inexact(1), Inexact(10), Absent, Exact(5)),
            (10, Inexact(1), Inexact(10), Absent, Absent),
            Some(Inexact(0)),
        ),
    ];
    for (left_info, right_info, expected_cardinality) in cases {
        let left_num_rows = left_info.0;
        let left_col_stats = vec![create_column_stats(
            left_info.1,
            left_info.2,
            left_info.3,
            left_info.4,
        )];
        let right_num_rows = right_info.0;
        let right_col_stats = vec![create_column_stats(
            right_info.1,
            right_info.2,
            right_info.3,
            right_info.4,
        )];
        // Direct inner-join estimate.
        assert_eq!(
            estimate_inner_join_cardinality(
                Statistics {
                    num_rows: Inexact(left_num_rows),
                    total_byte_size: Absent,
                    column_statistics: left_col_stats.clone(),
                },
                Statistics {
                    num_rows: Inexact(right_num_rows),
                    total_byte_size: Absent,
                    column_statistics: right_col_stats.clone(),
                },
            ),
            expected_cardinality.clone()
        );
        // The generic join estimator must agree for JoinType::Inner, and its
        // column statistics are the concatenation of both inputs'.
        let join_type = JoinType::Inner;
        let join_on = vec![(
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("b", 0)) as _,
        )];
        let partial_join_stats = estimate_join_cardinality(
            &join_type,
            create_stats(Some(left_num_rows), left_col_stats.clone(), false),
            create_stats(Some(right_num_rows), right_col_stats.clone(), false),
            &join_on,
        );
        assert_eq!(
            partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
            expected_cardinality.clone()
        );
        assert_eq!(
            partial_join_stats.map(|s| s.column_statistics),
            expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
        );
    }
    Ok(())
}
/// With multiple join columns the estimate divides the row-count product by
/// the largest per-column max-distinct count (here 200).
#[test]
fn test_inner_join_cardinality_multiple_column() -> Result<()> {
    let left_col_stats = vec![
        create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
        create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
    ];
    let right_col_stats = vec![
        create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
        create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
    ];
    assert_eq!(
        estimate_inner_join_cardinality(
            Statistics {
                num_rows: Inexact(400),
                total_byte_size: Absent,
                column_statistics: left_col_stats,
            },
            Statistics {
                num_rows: Inexact(400),
                total_byte_size: Absent,
                column_statistics: right_col_stats,
            },
        ),
        Some(Inexact((400 * 400) / 200))
    );
    Ok(())
}
/// Decimal128 min/max bounds must be usable for range overlap detection; the
/// right range lies inside the left, so a non-zero estimate is produced.
#[test]
fn test_inner_join_cardinality_decimal_range() -> Result<()> {
    let left_col_stats = vec![ColumnStatistics {
        distinct_count: Absent,
        min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)),
        max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)),
        ..Default::default()
    }];
    let right_col_stats = vec![ColumnStatistics {
        distinct_count: Absent,
        min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)),
        max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)),
        ..Default::default()
    }];
    assert_eq!(
        estimate_inner_join_cardinality(
            Statistics {
                num_rows: Inexact(100),
                total_byte_size: Absent,
                column_statistics: left_col_stats,
            },
            Statistics {
                num_rows: Inexact(100),
                total_byte_size: Absent,
                column_statistics: right_col_stats,
            },
        ),
        Some(Inexact(100))
    );
    Ok(())
}
/// Checks outer-join estimates relative to the inner estimate (800): outer
/// joins are floored by the preserved side's row count, and a full join by
/// the combined count.
#[test]
fn test_join_cardinality() -> Result<()> {
    // (join type, expected estimated output rows) for 1000 x 2000 inputs.
    let cases = vec![
        (JoinType::Inner, 800),
        (JoinType::Left, 1000),
        (JoinType::Right, 2000),
        (JoinType::Full, 2200),
    ];
    let left_col_stats = vec![
        create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
        create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
        create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
    ];
    let right_col_stats = vec![
        create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
        create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
        create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
    ];
    for (join_type, expected_num_rows) in cases {
        let join_on = vec![
            (
                Arc::new(Column::new("a", 0)) as _,
                Arc::new(Column::new("c", 0)) as _,
            ),
            (
                Arc::new(Column::new("b", 1)) as _,
                Arc::new(Column::new("d", 1)) as _,
            ),
        ];
        let partial_join_stats = estimate_join_cardinality(
            &join_type,
            create_stats(Some(1000), left_col_stats.clone(), false),
            create_stats(Some(2000), right_col_stats.clone(), false),
            &join_on,
        )
        .unwrap();
        assert_eq!(partial_join_stats.num_rows, expected_num_rows);
        // Output column statistics are the concatenation of both inputs'.
        assert_eq!(
            partial_join_stats.column_statistics,
            [left_col_stats.clone(), right_col_stats.clone()].concat()
        );
    }
    Ok(())
}
/// When one join-key pair has provably disjoint ranges the inner estimate is
/// zero; outer joins still retain their preserved side's rows, and a full
/// join retains both sides' rows.
#[test]
fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
    let left_col_stats = vec![
        create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
        create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
        create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
    ];
    let right_col_stats = vec![
        create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
        create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
        create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
    ];
    // Second key pair: [1000, 10000] vs [0, 100] — disjoint ranges.
    let join_on = vec![
        (
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("c", 0)) as _,
        ),
        (
            Arc::new(Column::new("x", 2)) as _,
            Arc::new(Column::new("y", 2)) as _,
        ),
    ];
    let cases = vec![
        (JoinType::Inner, 0),
        (JoinType::Left, 1000),
        (JoinType::Right, 2000),
        (JoinType::Full, 3000),
    ];
    for (join_type, expected_num_rows) in cases {
        let partial_join_stats = estimate_join_cardinality(
            &join_type,
            create_stats(Some(1000), left_col_stats.clone(), true),
            create_stats(Some(2000), right_col_stats.clone(), true),
            &join_on,
        )
        .unwrap();
        assert_eq!(partial_join_stats.num_rows, expected_num_rows);
        assert_eq!(
            partial_join_stats.column_statistics,
            [left_col_stats.clone(), right_col_stats.clone()].concat()
        );
    }
    Ok(())
}
/// Semi/anti join estimates are bounded by the outer side's row count:
/// semi joins return at most the outer count (zero when key ranges are
/// provably disjoint), while anti joins conservatively keep the outer count.
#[test]
fn test_anti_semi_join_cardinality() -> Result<()> {
    // (join type, outer stats, inner stats, expected estimate)
    let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
        // Overlapping ranges: semi joins are capped at the probe side size.
        (
            JoinType::LeftSemi,
            (50, Inexact(10), Inexact(20), Absent, Absent),
            (10, Inexact(15), Inexact(25), Absent, Absent),
            Some(50),
        ),
        (
            JoinType::RightSemi,
            (50, Inexact(10), Inexact(20), Absent, Absent),
            (10, Inexact(15), Inexact(25), Absent, Absent),
            Some(10),
        ),
        // No range info: fall back to the outer row count.
        (
            JoinType::LeftSemi,
            (10, Absent, Absent, Absent, Absent),
            (50, Absent, Absent, Absent, Absent),
            Some(10),
        ),
        // Provably disjoint ranges: semi join yields zero rows.
        (
            JoinType::LeftSemi,
            (50, Inexact(10), Inexact(20), Absent, Absent),
            (10, Inexact(30), Inexact(40), Absent, Absent),
            Some(0),
        ),
        (
            JoinType::LeftSemi,
            (50, Inexact(10), Absent, Absent, Absent),
            (10, Absent, Inexact(5), Absent, Absent),
            Some(0),
        ),
        (
            JoinType::LeftSemi,
            (50, Absent, Inexact(20), Absent, Absent),
            (10, Inexact(30), Absent, Absent, Absent),
            Some(0),
        ),
        // Anti joins: conservatively keep the outer row count, even when the
        // ranges are disjoint (everything would survive the anti filter).
        (
            JoinType::LeftAnti,
            (50, Inexact(10), Inexact(20), Absent, Absent),
            (10, Inexact(15), Inexact(25), Absent, Absent),
            Some(50),
        ),
        (
            JoinType::RightAnti,
            (50, Inexact(10), Inexact(20), Absent, Absent),
            (10, Inexact(15), Inexact(25), Absent, Absent),
            Some(10),
        ),
        (
            JoinType::LeftAnti,
            (10, Absent, Absent, Absent, Absent),
            (50, Absent, Absent, Absent, Absent),
            Some(10),
        ),
        (
            JoinType::LeftAnti,
            (50, Inexact(10), Inexact(20), Absent, Absent),
            (10, Inexact(30), Inexact(40), Absent, Absent),
            Some(50),
        ),
        (
            JoinType::LeftAnti,
            (50, Inexact(10), Absent, Absent, Absent),
            (10, Absent, Inexact(5), Absent, Absent),
            Some(50),
        ),
        (
            JoinType::LeftAnti,
            (50, Absent, Inexact(20), Absent, Absent),
            (10, Inexact(30), Absent, Absent, Absent),
            Some(50),
        ),
    ];
    let join_on = vec![(
        Arc::new(Column::new("l_col", 0)) as _,
        Arc::new(Column::new("r_col", 0)) as _,
    )];
    for (join_type, outer_info, inner_info, expected) in cases {
        let outer_num_rows = outer_info.0;
        let outer_col_stats = vec![create_column_stats(
            outer_info.1,
            outer_info.2,
            outer_info.3,
            outer_info.4,
        )];
        let inner_num_rows = inner_info.0;
        let inner_col_stats = vec![create_column_stats(
            inner_info.1,
            inner_info.2,
            inner_info.3,
            inner_info.4,
        )];
        let output_cardinality = estimate_join_cardinality(
            &join_type,
            Statistics {
                num_rows: Inexact(outer_num_rows),
                total_byte_size: Absent,
                column_statistics: outer_col_stats,
            },
            Statistics {
                num_rows: Inexact(inner_num_rows),
                total_byte_size: Absent,
                column_statistics: inner_col_stats,
            },
            &join_on,
        )
        .map(|cardinality| cardinality.num_rows);
        assert_eq!(
            output_cardinality, expected,
            "failure for join_type: {join_type}"
        );
    }
    Ok(())
}
/// Semi-join estimation with absent row counts: an absent outer count yields
/// no estimate, an absent inner count falls back to the outer count, and
/// both absent yields no estimate.
#[test]
fn test_semi_join_cardinality_absent_rows() -> Result<()> {
    let dummy_column_stats =
        vec![create_column_stats(Absent, Absent, Absent, Absent)];
    let join_on = vec![(
        Arc::new(Column::new("l_col", 0)) as _,
        Arc::new(Column::new("r_col", 0)) as _,
    )];
    // Outer (left) row count absent: no estimate is possible.
    let absent_outer_estimation = estimate_join_cardinality(
        &JoinType::LeftSemi,
        Statistics {
            num_rows: Absent,
            total_byte_size: Absent,
            column_statistics: dummy_column_stats.clone(),
        },
        Statistics {
            num_rows: Exact(10),
            total_byte_size: Absent,
            column_statistics: dummy_column_stats.clone(),
        },
        &join_on,
    );
    assert!(
        absent_outer_estimation.is_none(),
        "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
    );
    // Inner (right) row count absent: the outer count is used as the bound.
    let absent_inner_estimation = estimate_join_cardinality(
        &JoinType::LeftSemi,
        Statistics {
            num_rows: Inexact(500),
            total_byte_size: Absent,
            column_statistics: dummy_column_stats.clone(),
        },
        Statistics {
            num_rows: Absent,
            total_byte_size: Absent,
            column_statistics: dummy_column_stats.clone(),
        },
        &join_on,
    ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows");
    assert_eq!(
        absent_inner_estimation.num_rows, 500,
        "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"
    );
    // Both row counts absent: no estimate is possible.
    let absent_inner_estimation = estimate_join_cardinality(
        &JoinType::LeftSemi,
        Statistics {
            num_rows: Absent,
            total_byte_size: Absent,
            column_statistics: dummy_column_stats.clone(),
        },
        Statistics {
            num_rows: Absent,
            total_byte_size: Absent,
            column_statistics: dummy_column_stats,
        },
        &join_on,
    );
    assert!(
        absent_inner_estimation.is_none(),
        "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"
    );
    Ok(())
}
/// Output ordering of an inner join: the preserved (non-probe) side's
/// ordering comes first, and right-side column indices are shifted by
/// `left_columns_len` (5), mapping z:2 -> 7 and y:1 -> 6.
#[test]
fn test_calculate_join_output_ordering() -> Result<()> {
    let left_ordering = LexOrdering::new(vec![
        PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
        PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
        PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
    ]);
    let right_ordering = LexOrdering::new(vec![
        PhysicalSortExpr::new_default(Arc::new(Column::new("z", 2))),
        PhysicalSortExpr::new_default(Arc::new(Column::new("y", 1))),
    ]);
    let join_type = JoinType::Inner;
    let left_columns_len = 5;
    // Case 0: left order maintained, probing left -> left ordering leads.
    // Case 1: right order maintained, probing right -> right ordering leads.
    let maintains_input_orders = [[true, false], [false, true]];
    let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];
    let expected = [
        LexOrdering::new(vec![
            PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))),
        ]),
        LexOrdering::new(vec![
            PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))),
            PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))),
        ]),
    ];
    for (i, (maintains_input_order, probe_side)) in
        maintains_input_orders.iter().zip(probe_sides).enumerate()
    {
        assert_eq!(
            calculate_join_output_ordering(
                left_ordering.as_ref(),
                right_ordering.as_ref(),
                join_type,
                left_columns_len,
                maintains_input_order,
                probe_side,
            )?,
            expected[i]
        );
    }
    Ok(())
}
fn create_test_batch(num_rows: usize) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
RecordBatch::try_new(schema, vec![data]).unwrap()
}
fn assert_split_batches(
batches: Vec<(RecordBatch, bool)>,
batch_size: usize,
num_rows: usize,
) {
let mut row_count = 0;
for (batch, last) in batches.into_iter() {
assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size));
let column = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
for i in 0..batch.num_rows() {
assert_eq!(column.value(i), i as i32 + row_count as i32);
}
row_count += batch.num_rows();
assert_eq!(last, row_count == num_rows);
}
}
/// Drains a [`BatchSplitter`] across a matrix of batch sizes and row counts,
/// then validates slice sizes, contents, and last-flag placement.
#[rstest]
#[test]
fn test_batch_splitter(
    #[values(1, 3, 11)] batch_size: usize,
    #[values(1, 6, 50)] num_rows: usize,
) {
    let mut splitter = BatchSplitter::new(batch_size);
    splitter.set_batch(create_test_batch(num_rows));

    // Drain the splitter completely and confirm it stays exhausted.
    let mut pieces = Vec::with_capacity(num_rows.div_ceil(batch_size));
    while let Some(piece) = splitter.next() {
        pieces.push(piece);
    }
    assert!(splitter.next().is_none());

    assert_split_batches(pieces, batch_size, num_rows);
}
/// After a swap the output is `[right | left]`; the reverting projection
/// must restore `[left | right]`: left columns first but reading past the
/// right block, then the right column reading from index 0.
#[tokio::test]
async fn test_swap_reverting_projection() {
    let left_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]);

    let proj = swap_reverting_projection(&left_schema, &right_schema);

    let expected = [("a", 1), ("b", 2), ("c", 0)];
    assert_eq!(proj.len(), expected.len());
    for (proj_expr, (name, index)) in proj.iter().zip(expected) {
        assert_eq!(proj_expr.alias, name);
        assert_col_expr(&proj_expr.expr, name, index);
    }
}
/// Asserts that `expr` is a [`Column`] with the given name and index.
fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
    let column = expr
        .as_any()
        .downcast_ref::<Column>()
        .expect("Projection items should be Column expression");
    assert_eq!((column.name(), column.index()), (name, index));
}
/// The join schema's metadata follows the join type: Left joins keep the
/// left input's schema metadata, Right joins keep the right input's.
#[test]
fn test_join_metadata() -> Result<()> {
    let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)])
        .with_metadata(HashMap::from([("key".to_string(), "left".to_string())]));
    let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)])
        .with_metadata(HashMap::from([("key".to_string(), "right".to_string())]));

    for (join_type, expected_value) in
        [(JoinType::Left, "left"), (JoinType::Right, "right")]
    {
        let (join_schema, _) =
            build_join_schema(&left_schema, &right_schema, &join_type);
        assert_eq!(
            join_schema.metadata(),
            &HashMap::from([("key".to_string(), expected_value.to_string())])
        );
    }
    Ok(())
}
/// `build_batch_empty_build_side` with an empty output schema should yield a
/// batch that keeps the probe side's row count but carries zero columns.
#[test]
fn test_build_batch_empty_build_side_empty_schema() -> Result<()> {
    let empty_schema = Schema::empty();
    let build_batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])),
        vec![Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3]))],
    )?;
    let probe_batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])),
        vec![Arc::new(arrow::array::Int32Array::from(vec![4, 5, 6, 7]))],
    )?;
    let result = build_batch_empty_build_side(
        &empty_schema,
        &build_batch,
        &probe_batch,
        &[], // no join filter columns
        JoinType::Right,
    )?;
    // Row count comes from the probe side; the schema contributes no columns.
    assert_eq!(result.num_rows(), 4);
    assert_eq!(result.num_columns(), 0);
    Ok(())
}
}