use arrow::array::{
new_null_array, Array, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
UInt32Builder, UInt64Array,
};
use arrow::compute;
use arrow::datatypes::{Field, Schema, UInt32Type, UInt64Type};
use arrow::record_batch::{RecordBatch, RecordBatchOptions};
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
use parking_lot::Mutex;
use std::cmp::max;
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::usize;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::{ScalarValue, SharedResult};
use datafusion_physical_expr::rewrite::TreeNodeRewritable;
use datafusion_physical_expr::{EquivalentClass, PhysicalExpr};
use crate::error::{DataFusionError, Result};
use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use crate::physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
use crate::physical_plan::SchemaRef;
use crate::physical_plan::{
ColumnStatistics, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
};
/// The equi-join keys of a join: each entry pairs a column of the left input
/// (`.0`) with the column of the right input (`.1`) it is equated with.
pub type JoinOn = Vec<(Column, Column)>;
/// Borrowed form of [`JoinOn`].
pub type JoinOnRef<'a> = &'a [(Column, Column)];
/// Checks that every `on` key refers to a column present in the matching input schema.
///
/// Each schema field is turned into a [`Column`] carrying the field's name and
/// its position; the `on` pairs are then validated against those sets.
///
/// # Errors
/// Returns a `DataFusionError::Plan` listing the missing columns of each side.
pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
    // Name + ordinal of every field, as `Column`s.
    fn schema_columns(schema: &Schema) -> HashSet<Column> {
        schema
            .fields()
            .iter()
            .enumerate()
            .map(|(position, field)| Column::new(field.name(), position))
            .collect()
    }

    let left_columns = schema_columns(left);
    let right_columns = schema_columns(right);
    check_join_set_is_valid(&left_columns, &right_columns, on)
}
/// Validates `on` against pre-computed column sets of the left and right inputs.
///
/// # Errors
/// Returns a `DataFusionError::Plan` when any left key is missing from `left`
/// or any right key is missing from `right`.
fn check_join_set_is_valid(
    left: &HashSet<Column>,
    right: &HashSet<Column>,
    on: &[(Column, Column)],
) -> Result<()> {
    let on_left: HashSet<Column> = on.iter().map(|(l, _)| l.clone()).collect();
    let on_right: HashSet<Column> = on.iter().map(|(_, r)| r.clone()).collect();

    let left_missing: HashSet<&Column> = on_left.difference(left).collect();
    let right_missing: HashSet<&Column> = on_right.difference(right).collect();

    if left_missing.is_empty() && right_missing.is_empty() {
        return Ok(());
    }
    Err(DataFusionError::Plan(format!(
        "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}",
    )))
}
/// Derives the output partitioning of a partitioned join from its inputs' partitionings.
///
/// - Joins whose output is left-aligned (`Inner`, `Left`, `LeftSemi`, `LeftAnti`)
///   keep the left partitioning.
/// - `RightSemi`/`RightAnti` keep the right partitioning unchanged.
/// - `Right` keeps the right partitioning with its column references shifted by
///   `left_columns_len`.
/// - `Full` keeps only the right partition count.
pub fn partitioned_join_output_partitioning(
    join_type: JoinType,
    left_partitioning: Partitioning,
    right_partitioning: Partitioning,
    left_columns_len: usize,
) -> Partitioning {
    match join_type {
        JoinType::Full => {
            Partitioning::UnknownPartitioning(right_partitioning.partition_count())
        }
        JoinType::Right => {
            adjust_right_output_partitioning(right_partitioning, left_columns_len)
        }
        JoinType::RightSemi | JoinType::RightAnti => right_partitioning,
        JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
            left_partitioning
        }
    }
}
/// Re-expresses a right-input partitioning in terms of the joined `left ++ right` schema.
///
/// Hash partitioning expressions have every [`Column`] index shifted by
/// `left_columns_len`. Other variants carry no column references and are
/// returned as-is.
pub fn adjust_right_output_partitioning(
    right_partitioning: Partitioning,
    left_columns_len: usize,
) -> Partitioning {
    match right_partitioning {
        Partitioning::Hash(exprs, size) => {
            let shifted_exprs = exprs
                .into_iter()
                .map(|expr| {
                    // The rewrite closure never returns `Err`, so `unwrap` cannot fire.
                    expr.transform_down(&|node| {
                        if let Some(col) = node.as_any().downcast_ref::<Column>() {
                            Ok(Some(Arc::new(Column::new(
                                col.name(),
                                left_columns_len + col.index(),
                            ))))
                        } else {
                            Ok(None)
                        }
                    })
                    .unwrap()
                })
                .collect::<Vec<_>>();
            Partitioning::Hash(shifted_exprs, size)
        }
        Partitioning::RoundRobinBatch(count) => Partitioning::RoundRobinBatch(count),
        Partitioning::UnknownPartitioning(count) => {
            Partitioning::UnknownPartitioning(count)
        }
    }
}
/// Builds the equivalence properties of a join's output schema.
///
/// - For `Inner`/`Left`/`Full`/`Right`, the left classes are kept and the right
///   classes are re-indexed by `left_columns_len` (the output is `left ++ right`).
/// - Semi/anti joins keep only the classes of the side they output.
/// - Inner joins additionally record every `on` pair as an equality.
pub fn combine_join_equivalence_properties(
    join_type: JoinType,
    left_properties: EquivalenceProperties,
    right_properties: EquivalenceProperties,
    left_columns_len: usize,
    on: &[(Column, Column)],
    schema: SchemaRef,
) -> EquivalenceProperties {
    // Re-bases a right-input column onto the combined `left ++ right` schema.
    let shift_right =
        |col: &Column| Column::new(col.name(), left_columns_len + col.index());

    let mut combined = EquivalenceProperties::new(schema);
    match join_type {
        JoinType::LeftSemi | JoinType::LeftAnti => {
            combined.extend(left_properties.classes().to_vec())
        }
        JoinType::RightSemi | JoinType::RightAnti => {
            combined.extend(right_properties.classes().to_vec())
        }
        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
            combined.extend(left_properties.classes().to_vec());
            let shifted_right_classes = right_properties
                .classes()
                .iter()
                .map(|class| {
                    let head = shift_right(class.head());
                    let others = class
                        .others()
                        .iter()
                        .map(shift_right)
                        .collect::<Vec<_>>();
                    EquivalentClass::new(head, others)
                })
                .collect::<Vec<_>>();
            combined.extend(shifted_right_classes);
        }
    }

    // Only an inner join lets us equate the key columns in the output.
    if matches!(join_type, JoinType::Inner) {
        for (left_col, right_col) in on {
            combined.add_equal_conditions((left_col, &shift_right(right_col)));
        }
    }
    combined
}
pub fn cross_join_equivalence_properties(
left_properties: EquivalenceProperties,
right_properties: EquivalenceProperties,
left_columns_len: usize,
schema: SchemaRef,
) -> EquivalenceProperties {
let mut new_properties = EquivalenceProperties::new(schema);
new_properties.extend(left_properties.classes().to_vec());
let new_right_properties = right_properties
.classes()
.iter()
.map(|prop| {
let new_head =
Column::new(prop.head().name(), left_columns_len + prop.head().index());
let new_others = prop
.others()
.iter()
.map(|col| Column::new(col.name(), left_columns_len + col.index()))
.collect::<Vec<_>>();
EquivalentClass::new(new_head, new_others)
})
.collect::<Vec<_>>();
new_properties.extend(new_right_properties);
new_properties
}
/// Identifies one of the two inputs of a join.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JoinSide {
    /// The left input.
    Left,
    /// The right input.
    Right,
}

impl JoinSide {
    /// Returns the opposite side.
    pub fn negate(&self) -> Self {
        match *self {
            JoinSide::Left => JoinSide::Right,
            JoinSide::Right => JoinSide::Left,
        }
    }
}

/// Formats the side as the lowercase word `left` or `right`.
impl Display for JoinSide {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            JoinSide::Left => "left",
            JoinSide::Right => "right",
        };
        f.write_str(label)
    }
}
/// Locates a column of a join's intermediate or output schema in one of the join inputs.
#[derive(Debug, Clone)]
pub struct ColumnIndex {
    /// Position of the column within the input selected by `side`.
    pub index: usize,
    /// Which join input the column is read from.
    pub side: JoinSide,
}
/// A non-equi join predicate evaluated over an intermediate batch.
///
/// The intermediate batch has `schema`; its i-th column is taken from the join
/// input/position described by `column_indices[i]`.
#[derive(Debug, Clone)]
pub struct JoinFilter {
    expression: Arc<dyn PhysicalExpr>,
    column_indices: Vec<ColumnIndex>,
    schema: Schema,
}

impl JoinFilter {
    /// Creates a filter from its predicate, column mapping and intermediate schema.
    pub fn new(
        expression: Arc<dyn PhysicalExpr>,
        column_indices: Vec<ColumnIndex>,
        schema: Schema,
    ) -> JoinFilter {
        Self {
            expression,
            column_indices,
            schema,
        }
    }

    /// Tags `left_indices` as left-side and `right_indices` as right-side
    /// columns, left ones first, preserving the order within each list.
    pub fn build_column_indices(
        left_indices: Vec<usize>,
        right_indices: Vec<usize>,
    ) -> Vec<ColumnIndex> {
        let mut tagged = Vec::with_capacity(left_indices.len() + right_indices.len());
        tagged.extend(left_indices.into_iter().map(|index| ColumnIndex {
            index,
            side: JoinSide::Left,
        }));
        tagged.extend(right_indices.into_iter().map(|index| ColumnIndex {
            index,
            side: JoinSide::Right,
        }));
        tagged
    }

    /// The predicate expression.
    pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
        &self.expression
    }

    /// Mapping from intermediate-batch columns to join-input columns.
    pub fn column_indices(&self) -> &[ColumnIndex] {
        &self.column_indices
    }

    /// Schema of the intermediate batch the predicate is evaluated on.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }
}
/// Returns `old_field` as it appears in a join's output.
///
/// A field is forced nullable when the join can emit rows in which its side
/// is missing:
/// - the right side of a `Left` join;
/// - the left side of a `Right` join;
/// - both sides of a `Full` join.
///
/// Otherwise the field is returned unchanged.
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
    let may_be_missing = match join_type {
        JoinType::Full => true,
        JoinType::Left => !is_left,
        JoinType::Right => is_left,
        JoinType::Inner
        | JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti => false,
    };
    let field = old_field.clone();
    if may_be_missing {
        field.with_nullable(true)
    } else {
        field
    }
}
pub fn build_join_schema(
left: &Schema,
right: &Schema,
join_type: &JoinType,
) -> (Schema, Vec<ColumnIndex>) {
let (fields, column_indices): (Vec<Field>, Vec<ColumnIndex>) = match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
let left_fields = left
.fields()
.iter()
.map(|f| output_join_field(f, join_type, true))
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Left,
},
)
});
let right_fields = right
.fields()
.iter()
.map(|f| output_join_field(f, join_type, false))
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
});
left_fields.chain(right_fields).unzip()
}
JoinType::LeftSemi | JoinType::LeftAnti => left
.fields()
.iter()
.cloned()
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Left,
},
)
})
.unzip(),
JoinType::RightSemi | JoinType::RightAnti => right
.fields()
.iter()
.cloned()
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
})
.unzip(),
};
(Schema::new(fields), column_indices)
}
/// Lazily creates, at most once, a [`OnceFut`] and hands out clones of it.
///
/// The first call to [`OnceAsync::once`] runs the supplied closure while holding
/// the internal lock; later calls ignore their closure and return a clone of the
/// stored future.
pub(crate) struct OnceAsync<T> {
    fut: Mutex<Option<OnceFut<T>>>,
}

impl<T> Default for OnceAsync<T> {
    fn default() -> Self {
        let fut = Mutex::new(None);
        Self { fut }
    }
}

impl<T> std::fmt::Debug for OnceAsync<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("OnceAsync")
    }
}

impl<T: 'static> OnceAsync<T> {
    /// Returns the shared future, creating it from `f()` on first use only.
    pub(crate) fn once<F, Fut>(&self, f: F) -> OnceFut<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        let mut slot = self.fut.lock();
        let shared = slot.get_or_insert_with(|| OnceFut::new(f()));
        shared.clone()
    }
}
/// The in-flight state of a [`OnceFut`]. The output (or error) is wrapped in `Arc`
/// so the boxed future can be made `Shared`, which requires a cloneable output.
type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;
/// A cloneable handle on a single computation producing `T`.
///
/// Clones of the handle poll the same underlying `Shared` future; each handle
/// caches the result once it has observed completion.
pub(crate) struct OnceFut<T> {
    state: OnceFutState<T>,
}
// Hand-written rather than derived: `#[derive(Clone)]` would add a `T: Clone`
// bound, which is unnecessary because the state only holds `Arc<T>` / a shared future.
impl<T> Clone for OnceFut<T> {
    fn clone(&self) -> Self {
        Self {
            state: self.state.clone(),
        }
    }
}
/// Result of `estimate_join_cardinality`: the estimated number of output rows
/// and the column statistics of the joined schema (all left columns followed
/// by all right columns).
#[derive(Clone, Debug, Default)]
struct PartialJoinStatistics {
    /// Estimated output row count.
    pub num_rows: usize,
    /// Left input column statistics followed by right input column statistics.
    pub column_statistics: Vec<ColumnStatistics>,
}
/// Estimates the statistics of joining `left` and `right` on `on`.
///
/// When a cardinality estimate is available, the row count and the
/// concatenated input column statistics are filled in. Otherwise both are
/// `None`. The result is never exact and carries no byte size.
pub(crate) fn estimate_join_statistics(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    on: JoinOn,
    join_type: &JoinType,
) -> Statistics {
    let left_stats = left.statistics();
    let right_stats = right.statistics();
    let estimate = estimate_join_cardinality(join_type, left_stats, right_stats, &on);

    let num_rows = estimate.as_ref().map(|stats| stats.num_rows);
    let column_statistics = estimate.map(|stats| stats.column_statistics);

    Statistics {
        num_rows,
        total_byte_size: None,
        column_statistics,
        is_exact: false,
    }
}
/// Estimates the output cardinality of an equi-join.
///
/// Returns `None` for semi/anti joins, or when row counts, column statistics
/// or the inner-join estimate are unavailable. Outer joins are derived from
/// the inner estimate `I`:
/// - `Left`: `max(I, L)`;
/// - `Right`: `max(I, R)`;
/// - `Full`: `max(I, L) + max(I, R) - I`.
///
/// # Panics
/// Panics if an `on` column index is out of range of the corresponding input's
/// column statistics.
fn estimate_join_cardinality(
    join_type: &JoinType,
    left_stats: Statistics,
    right_stats: Statistics,
    on: &JoinOn,
) -> Option<PartialJoinStatistics> {
    if matches!(
        join_type,
        JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti
    ) {
        return None;
    }

    let left_num_rows = left_stats.num_rows?;
    let right_num_rows = right_stats.num_rows?;
    let both_exact = left_stats.is_exact && right_stats.is_exact;
    let left_columns = left_stats.column_statistics?;
    let right_columns = right_stats.column_statistics?;

    // Statistics of the key columns only, in `on` order.
    let mut left_key_stats = Vec::with_capacity(on.len());
    let mut right_key_stats = Vec::with_capacity(on.len());
    for (left_key, right_key) in on {
        left_key_stats.push(left_columns[left_key.index()].clone());
        right_key_stats.push(right_columns[right_key.index()].clone());
    }

    let inner = estimate_inner_join_cardinality(
        left_num_rows,
        right_num_rows,
        left_key_stats,
        right_key_stats,
        both_exact,
    )?;

    let num_rows = match join_type {
        JoinType::Inner => inner,
        JoinType::Left => inner.max(left_num_rows),
        JoinType::Right => inner.max(right_num_rows),
        JoinType::Full => inner.max(left_num_rows) + inner.max(right_num_rows) - inner,
        _ => unreachable!(),
    };

    let mut column_statistics = left_columns;
    column_statistics.extend(right_columns);
    Some(PartialJoinStatistics {
        num_rows,
        column_statistics,
    })
}
/// Estimates the cardinality of an inner equi-join from per-key column statistics.
///
/// For each key pair, if the `[min, max]` ranges are disjoint the join is empty:
/// the function returns `Some(0)` when `is_exact`, otherwise `None`. Otherwise
/// the largest maximum-distinct-count over all key columns (from
/// `max_distinct_count`) is used as the selectivity and the estimate is
/// `left_num_rows * right_num_rows / selectivity`.
///
/// Returns `None` when a needed min/max bound is missing, when no positive
/// selectivity is available, or when the cartesian product would overflow
/// `usize`.
fn estimate_inner_join_cardinality(
    left_num_rows: usize,
    right_num_rows: usize,
    left_col_stats: Vec<ColumnStatistics>,
    right_col_stats: Vec<ColumnStatistics>,
    is_exact: bool,
) -> Option<usize> {
    let mut join_selectivity = None;
    for (left_stat, right_stat) in left_col_stats.iter().zip(right_col_stats.iter()) {
        // Compare by reference (no ScalarValue clones). The short-circuit order
        // matters: a disjoint first comparison wins even if the second pair
        // of bounds is missing.
        if (left_stat.min_value.as_ref()? > right_stat.max_value.as_ref()?)
            || (left_stat.max_value.as_ref()? < right_stat.min_value.as_ref()?)
        {
            return if is_exact { Some(0) } else { None };
        }

        let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone());
        let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone());
        let max_distinct = max(left_max_distinct, right_max_distinct);
        if max_distinct > join_selectivity {
            join_selectivity = max_distinct;
        }
    }

    match join_selectivity {
        // `checked_mul`: a huge cartesian product must not overflow (a panic in
        // debug builds, a silently wrapped estimate in release builds).
        Some(selectivity) if selectivity > 0 => left_num_rows
            .checked_mul(right_num_rows)
            .map(|cartesian| cartesian / selectivity),
        _ => None,
    }
}
/// Upper bound on the number of distinct values in a column.
///
/// Uses `distinct_count` when present. Otherwise it takes the size of the
/// integer `[min, max]` range (via `get_int_range`), capped at the number of
/// non-null rows. Returns `None` when neither source is available.
fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option<usize> {
    if let Some(distinct) = stats.distinct_count {
        return Some(distinct);
    }
    match (stats.max_value, stats.min_value) {
        (Some(max), Some(min)) => {
            let numeric_range = get_int_range(min, max)?;
            // Inexact statistics may report more nulls than rows; saturate at 0
            // instead of underflowing (a panic in debug builds, a huge wrapped
            // ceiling in release builds).
            let ceiling = num_rows.saturating_sub(stats.null_count.unwrap_or(0));
            Some(numeric_range.min(ceiling))
        }
        _ => None,
    }
}
/// Number of integer values in the inclusive range `[min, max]`.
///
/// Only signed and unsigned integer scalars of 8 to 64 bits are supported.
/// Returns `None` for:
/// - other types;
/// - a failed subtraction or a null delta;
/// - a negative delta (`max < min`);
/// - a delta that does not fit in `usize`;
/// - a range size that would overflow `usize`.
fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option<usize> {
    let delta = &max.sub(&min).ok()?;
    // `try_from` rejects negative deltas and, unlike `as usize`, never silently
    // truncates 64-bit deltas on 32-bit targets.
    let open_ended_range = match delta {
        ScalarValue::Int8(Some(delta)) => usize::try_from(*delta).ok(),
        ScalarValue::Int16(Some(delta)) => usize::try_from(*delta).ok(),
        ScalarValue::Int32(Some(delta)) => usize::try_from(*delta).ok(),
        ScalarValue::Int64(Some(delta)) => usize::try_from(*delta).ok(),
        ScalarValue::UInt8(Some(delta)) => Some(usize::from(*delta)),
        ScalarValue::UInt16(Some(delta)) => Some(usize::from(*delta)),
        ScalarValue::UInt32(Some(delta)) => usize::try_from(*delta).ok(),
        ScalarValue::UInt64(Some(delta)) => usize::try_from(*delta).ok(),
        _ => None,
    }?;
    // Both endpoints are included. `checked_add` guards a `usize::MAX` delta
    // (e.g. a full-range UInt64 column) against overflow.
    open_ended_range.checked_add(1)
}
/// State of a [`OnceFut`]: still waiting on the shared future, or holding its result.
enum OnceFutState<T> {
    /// The shared future has not yet been observed to complete by this handle.
    Pending(OnceFutPending<T>),
    /// The cached result (value or error), both behind `Arc`.
    Ready(SharedResult<Arc<T>>),
}
// Hand-written so that `T` is not required to be `Clone`: both variants only
// contain `Arc`s / a `Shared` future, which are cloneable for any `T` here.
impl<T> Clone for OnceFutState<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Pending(p) => Self::Pending(p.clone()),
            Self::Ready(r) => Self::Ready(r.clone()),
        }
    }
}
impl<T: 'static> OnceFut<T> {
    /// Wraps `fut` so that its output can be shared by all clones of the handle.
    ///
    /// The value and the error are both placed behind `Arc` so the boxed
    /// future can be turned into a `Shared` future.
    pub(crate) fn new<Fut>(fut: Fut) -> Self
    where
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        let shared = fut
            .map(|outcome| outcome.map(Arc::new).map_err(Arc::new))
            .boxed()
            .shared();
        Self {
            state: OnceFutState::Pending(shared),
        }
    }

    /// Polls the shared future and, once it completes, borrows its value.
    ///
    /// The completed result is cached in this handle, so subsequent calls do not
    /// poll again. A shared error is re-surfaced as `DataFusionError::External`
    /// wrapping a clone of the `Arc`'d error.
    pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
        if let OnceFutState::Pending(pending) = &mut self.state {
            let outcome = ready!(pending.poll_unpin(cx));
            self.state = OnceFutState::Ready(outcome);
        }

        let outcome = match &self.state {
            OnceFutState::Ready(outcome) => outcome,
            // Just transitioned to `Ready` above (or returned `Pending`).
            OnceFutState::Pending(_) => unreachable!(),
        };

        Poll::Ready(match outcome {
            Ok(value) => Ok(value.as_ref()),
            Err(shared_err) => {
                Err(DataFusionError::External(Box::new(shared_err.clone())))
            }
        })
    }
}
/// Whether the join must emit extra rows once the probe side is exhausted.
///
/// True for `Left`, `LeftAnti`, `LeftSemi` and `Full`. The rows to emit are
/// presumably derived from the build-side visited bitmap (see
/// `get_final_indices_from_bit_map`) — confirm at the call sites.
pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
    match join_type {
        JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full => true,
        JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
            false
        }
    }
}
/// Turns the left-side visited bitmap into the final (left, right) index pair.
///
/// - For `LeftSemi`, the left indices are the positions whose bit is set.
/// - For every other join type, they are the positions whose bit is unset.
///
/// The right indices are all nulls, one per selected left index.
pub(crate) fn get_final_indices_from_bit_map(
    left_bit_map: &BooleanBufferBuilder,
    join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
    let want_set_bits = join_type == JoinType::LeftSemi;
    let left_indices: UInt64Array = (0..left_bit_map.len())
        .filter(|&row| left_bit_map.get_bit(row) == want_set_bits)
        .map(|row| row as u64)
        .collect();

    let mut null_right = UInt32Builder::with_capacity(left_indices.len());
    null_right.append_nulls(left_indices.len());
    (left_indices, null_right.finish())
}
/// Keeps only the (build, probe) index pairs for which `filter` evaluates to true.
///
/// An intermediate batch holding the filter's columns is materialised for
/// every candidate pair. The predicate is evaluated on it, and the resulting
/// boolean mask is applied to both index arrays. Empty inputs are returned
/// as-is.
///
/// # Errors
/// Propagates failures from batch construction, expression evaluation, the
/// boolean downcast of the predicate result, and the Arrow `filter` kernel.
pub(crate) fn apply_join_filter_to_indices(
    build_input_buffer: &RecordBatch,
    probe_batch: &RecordBatch,
    build_indices: UInt64Array,
    probe_indices: UInt32Array,
    filter: &JoinFilter,
    build_side: JoinSide,
) -> Result<(UInt64Array, UInt32Array)> {
    if build_indices.is_empty() && probe_indices.is_empty() {
        return Ok((build_indices, probe_indices));
    }

    let candidate_rows = build_batch_from_indices(
        filter.schema(),
        build_input_buffer,
        probe_batch,
        build_indices.clone(),
        probe_indices.clone(),
        filter.column_indices(),
        build_side,
    )?;

    let evaluated = filter
        .expression()
        .evaluate(&candidate_rows)?
        .into_array(candidate_rows.num_rows());
    let keep = as_boolean_array(&evaluated)?;

    let kept_build = compute::filter(&build_indices, keep)?;
    let kept_probe = compute::filter(&probe_indices, keep)?;
    Ok((
        PrimitiveArray::<UInt64Type>::from(kept_build.data().clone()),
        PrimitiveArray::<UInt32Type>::from(kept_probe.data().clone()),
    ))
}
/// Gathers the rows addressed by `build_indices`/`probe_indices` into a batch of `schema`.
///
/// Output column i is taken from the build or probe batch according to
/// `column_indices[i]`. If that source column is empty, or every index on its
/// side is null, an all-null column is produced instead of calling `take`.
/// A schema without fields yields a column-less batch whose row count is
/// `build_indices.len()`.
///
/// # Panics
/// Panics if a source column is empty while its index array contains non-null
/// entries.
///
/// # Errors
/// Propagates Arrow `take` and `RecordBatch` construction errors.
pub(crate) fn build_batch_from_indices(
    schema: &Schema,
    build_input_buffer: &RecordBatch,
    probe_batch: &RecordBatch,
    build_indices: UInt64Array,
    probe_indices: UInt32Array,
    column_indices: &[ColumnIndex],
    build_side: JoinSide,
) -> Result<RecordBatch> {
    let output_schema = Arc::new(schema.clone());

    if schema.fields().is_empty() {
        let options = RecordBatchOptions::new()
            .with_match_field_names(true)
            .with_row_count(Some(build_indices.len()));
        return Ok(RecordBatch::try_new_with_options(
            output_schema,
            vec![],
            &options,
        )?);
    }

    let columns = column_indices
        .iter()
        .map(|column_index| -> Result<Arc<dyn Array>> {
            if column_index.side == build_side {
                let source = build_input_buffer.column(column_index.index);
                let all_null = build_indices.null_count() == build_indices.len();
                if source.is_empty() || all_null {
                    assert_eq!(build_indices.null_count(), build_indices.len());
                    Ok(new_null_array(source.data_type(), build_indices.len()))
                } else {
                    Ok(compute::take(source.as_ref(), &build_indices, None)?)
                }
            } else {
                let source = probe_batch.column(column_index.index);
                let all_null = probe_indices.null_count() == probe_indices.len();
                if source.is_empty() || all_null {
                    assert_eq!(probe_indices.null_count(), probe_indices.len());
                    Ok(new_null_array(source.data_type(), probe_indices.len()))
                } else {
                    Ok(compute::take(source.as_ref(), &probe_indices, None)?)
                }
            }
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(RecordBatch::try_new(output_schema, columns)?)
}
/// Post-processes the matched (left, right) index pairs of one probe batch for `join_type`.
///
/// - `Inner`, `Left`: the pairs are returned unchanged.
/// - `Right`, `Full`: unmatched right rows are appended, paired with null left indices.
/// - `RightSemi`: the left indices are kept and the right indices are replaced by
///   the distinct matched right rows, in ascending order.
/// - `RightAnti`: the left indices are kept and the right indices are replaced by
///   the unmatched right rows.
/// - `LeftSemi`, `LeftAnti`: both arrays are empty.
pub(crate) fn adjust_indices_by_join_type(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    count_right_batch: usize,
    join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
    match join_type {
        JoinType::Inner | JoinType::Left => (left_indices, right_indices),
        JoinType::Right | JoinType::Full => {
            let unmatched_right = get_anti_indices(count_right_batch, &right_indices);
            append_right_indices(left_indices, right_indices, unmatched_right)
        }
        JoinType::RightSemi => (
            left_indices,
            get_semi_indices(count_right_batch, &right_indices),
        ),
        JoinType::RightAnti => (
            left_indices,
            get_anti_indices(count_right_batch, &right_indices),
        ),
        JoinType::LeftSemi | JoinType::LeftAnti => (
            UInt64Array::from_iter_values(std::iter::empty()),
            UInt32Array::from_iter_values(std::iter::empty()),
        ),
    }
}
/// Appends `right_unmatched_indices` to `right_indices`, padding `left_indices` with nulls.
///
/// The two returned arrays have equal length. When there is nothing to append,
/// the inputs are returned untouched.
pub(crate) fn append_right_indices(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    right_unmatched_indices: UInt32Array,
) -> (UInt64Array, UInt32Array) {
    if right_unmatched_indices.is_empty() {
        return (left_indices, right_indices);
    }

    let padding = right_unmatched_indices.len();
    let padded_left: UInt64Array = left_indices
        .iter()
        .chain(std::iter::repeat(None).take(padding))
        .collect();
    let extended_right: UInt32Array = right_indices
        .iter()
        .chain(right_unmatched_indices.iter())
        .collect();
    (padded_left, extended_right)
}
pub(crate) fn get_anti_indices(
row_count: usize,
input_indices: &UInt32Array,
) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
});
(0..row_count)
.filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32))
.collect::<UInt32Array>()
}
pub(crate) fn get_anti_u64_indices(
row_count: usize,
input_indices: &UInt64Array,
) -> UInt64Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
});
(0..row_count)
.filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u64))
.collect::<UInt64Array>()
}
pub(crate) fn get_semi_indices(
row_count: usize,
input_indices: &UInt32Array,
) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
});
(0..row_count)
.filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32))
.collect::<UInt32Array>()
}
pub(crate) fn get_semi_u64_indices(
row_count: usize,
input_indices: &UInt64Array,
) -> UInt64Array {
let mut bitmap = BooleanBufferBuilder::new(row_count);
bitmap.append_n(row_count, false);
input_indices.iter().flatten().for_each(|v| {
bitmap.set_bit(v as usize, true);
});
(0..row_count)
.filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u64))
.collect::<UInt64Array>()
}
/// Per-partition metrics shared by build/probe style joins.
///
/// Each field is registered in the plan's metrics set under the name noted on it.
#[derive(Clone, Debug)]
pub(crate) struct BuildProbeJoinMetrics {
    /// Subset timer `build_time`.
    pub(crate) build_time: metrics::Time,
    /// Counter `build_input_batches`.
    pub(crate) build_input_batches: metrics::Count,
    /// Counter `build_input_rows`.
    pub(crate) build_input_rows: metrics::Count,
    /// Gauge `build_mem_used`.
    pub(crate) build_mem_used: metrics::Gauge,
    /// Subset timer `join_time`.
    pub(crate) join_time: metrics::Time,
    /// Counter `input_batches`.
    pub(crate) input_batches: metrics::Count,
    /// Counter `input_rows`.
    pub(crate) input_rows: metrics::Count,
    /// Counter `output_batches`.
    pub(crate) output_batches: metrics::Count,
    /// The standard output-rows metric.
    pub(crate) output_rows: metrics::Count,
}

impl BuildProbeJoinMetrics {
    /// Registers all metrics for `partition` in `metrics`.
    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
        let builder = move || MetricBuilder::new(metrics);
        // Field initialisers run in the written order, matching the original
        // registration order (join_time first).
        Self {
            join_time: builder().subset_time("join_time", partition),
            build_time: builder().subset_time("build_time", partition),
            build_input_batches: builder().counter("build_input_batches", partition),
            build_input_rows: builder().counter("build_input_rows", partition),
            build_mem_used: builder().gauge("build_mem_used", partition),
            input_batches: builder().counter("input_batches", partition),
            input_rows: builder().counter("input_rows", partition),
            output_batches: builder().counter("output_batches", partition),
            output_rows: builder().output_rows(partition),
        }
    }
}
// Unit tests for this module's join helpers:
// - `on`-clause validation;
// - join schema nullability;
// - `OnceFut` error propagation;
// - the join cardinality estimators.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::error::Result as ArrowResult;
    use arrow::{datatypes::DataType, error::ArrowError};
    use datafusion_common::ScalarValue;
    use std::pin::Pin;

    /// Runs `check_join_set_is_valid` on column sets built from the given slices.
    fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
        let left = left
            .iter()
            .map(|x| x.to_owned())
            .collect::<HashSet<Column>>();
        let right = right
            .iter()
            .map(|x| x.to_owned())
            .collect::<HashSet<Column>>();
        check_join_set_is_valid(&left, &right, on)
    }

    #[test]
    fn check_valid() -> Result<()> {
        let left = vec![Column::new("a", 0), Column::new("b1", 1)];
        let right = vec![Column::new("a", 0), Column::new("b2", 1)];
        let on = &[(Column::new("a", 0), Column::new("a", 0))];
        check(&left, &right, on)?;
        Ok(())
    }

    // The right key `a@0` is not among the right columns.
    #[test]
    fn check_not_in_right() {
        let left = vec![Column::new("a", 0), Column::new("b", 1)];
        let right = vec![Column::new("b", 0)];
        let on = &[(Column::new("a", 0), Column::new("a", 0))];
        assert!(check(&left, &right, on).is_err());
    }

    // `OnceFut::get` re-wraps a shared error in `DataFusionError::External`.
    // After converting to an Arrow error and back, `find_root` must still
    // reach the original `ArrowError::CsvError`.
    #[tokio::test]
    async fn check_error_nesting() {
        let once_fut = OnceFut::<()>::new(async {
            Err(DataFusionError::ArrowError(ArrowError::CsvError(
                "some error".to_string(),
            )))
        });
        struct TestFut(OnceFut<()>);
        impl Future for TestFut {
            type Output = ArrowResult<()>;
            fn poll(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Self::Output> {
                match ready!(self.0.get(cx)) {
                    Ok(()) => Poll::Ready(Ok(())),
                    Err(e) => Poll::Ready(Err(e.into())),
                }
            }
        }
        let res = TestFut(once_fut).await;
        let arrow_err_from_fut = res.expect_err("once_fut always return error");
        let wrapped_err = DataFusionError::from(arrow_err_from_fut);
        let root_err = wrapped_err.find_root();
        assert!(matches!(
            root_err,
            DataFusionError::ArrowError(ArrowError::CsvError(_))
        ))
    }

    // The left key `a@0` is not among the left columns.
    #[test]
    fn check_not_in_left() {
        let left = vec![Column::new("b", 0)];
        let right = vec![Column::new("a", 0)];
        let on = &[(Column::new("a", 0), Column::new("a", 0))];
        assert!(check(&left, &right, on).is_err());
    }

    // Both sides may share a column name; each key only needs to exist on its own side.
    #[test]
    fn check_collision() {
        let left = vec![Column::new("a", 0), Column::new("c", 1)];
        let right = vec![Column::new("a", 0), Column::new("b", 1)];
        let on = &[(Column::new("a", 0), Column::new("b", 1))];
        assert!(check(&left, &right, on).is_ok());
    }

    #[test]
    fn check_in_right() {
        let left = vec![Column::new("a", 0), Column::new("c", 1)];
        let right = vec![Column::new("b", 0)];
        let on = &[(Column::new("a", 0), Column::new("b", 0))];
        assert!(check(&left, &right, on).is_ok());
    }

    // Each case is (left in, right in, join type, expected left out, expected
    // right out). Only a side that the join may leave unmatched (the right of
    // `Left`, the left of `Right`, both of `Full`) is forced nullable.
    #[test]
    fn test_join_schema() -> Result<()> {
        let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
        let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);
        let cases = vec![
            (&a, &b, JoinType::Inner, &a, &b),
            (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
            (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
            (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
            (&a, &b, JoinType::Left, &a, &b_nulls),
            (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
            (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
            (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
            (&a, &b, JoinType::Right, &a_nulls, &b),
            (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
            (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
            (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
            (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
            (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
            (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
            (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
        ];
        for (left_in, right_in, join_type, left_out, right_out) in cases {
            let (schema, _) = build_join_schema(left_in, right_in, &join_type);
            let expected_fields = left_out
                .fields()
                .iter()
                .cloned()
                .chain(right_out.fields().iter().cloned())
                .collect();
            let expected_schema = Schema::new(expected_fields);
            assert_eq!(
                schema,
                expected_schema,
                "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
                left_in.fields()[0].name(),
                left_in.fields()[0].is_nullable(),
                right_in.fields()[0].name(),
                right_in.fields()[0].is_nullable(),
                join_type
            );
        }
        Ok(())
    }

    /// Builds `Statistics` with the given row count, column stats and exactness.
    fn create_stats(
        num_rows: Option<usize>,
        column_stats: Option<Vec<ColumnStatistics>>,
        is_exact: bool,
    ) -> Statistics {
        Statistics {
            num_rows,
            column_statistics: column_stats,
            is_exact,
            ..Default::default()
        }
    }

    /// Builds Int64 `ColumnStatistics` with optional min, max and distinct count.
    fn create_column_stats(
        min: Option<i64>,
        max: Option<i64>,
        distinct_count: Option<usize>,
    ) -> ColumnStatistics {
        ColumnStatistics {
            distinct_count,
            min_value: min.map(|size| ScalarValue::Int64(Some(size))),
            max_value: max.map(|size| ScalarValue::Int64(Some(size))),
            ..Default::default()
        }
    }

    /// (num_rows, min, max, distinct_count) of one side's single key column.
    type PartialStats = (usize, Option<i64>, Option<i64>, Option<usize>);

    #[test]
    fn test_inner_join_cardinality_single_column() -> Result<()> {
        let cases: Vec<(PartialStats, PartialStats, Option<usize>)> = vec![
            // Identical [1, 10] ranges: 10 possible values -> 10 * 10 / 10.
            (
                (10, Some(1), Some(10), None),
                (10, Some(1), Some(10), None),
                Some(10),
            ),
            // Overlapping ranges: the larger range size ([6, 10] = 5) wins -> 100 / 5.
            (
                (10, Some(6), Some(10), None),
                (10, Some(8), Some(10), None),
                Some(20),
            ),
            // Same as above with the sides swapped.
            (
                (10, Some(8), Some(10), None),
                (10, Some(6), Some(10), None),
                Some(20),
            ),
            // Range sizes are capped at the row count: left min(15, 10) = 10,
            // right min(40, 20) = 20 -> 10 * 20 / 20.
            (
                (10, Some(1), Some(15), None),
                (20, Some(1), Some(40), None),
                Some(10),
            ),
            // Explicit distinct counts are used instead of the ranges.
            (
                (10, Some(1), Some(10), Some(10)),
                (10, Some(1), Some(10), Some(10)),
                Some(10),
            ),
            // The larger distinct count (5) wins -> 100 / 5.
            (
                (10, Some(1), Some(10), Some(5)),
                (10, Some(1), Some(10), Some(2)),
                Some(20),
            ),
            // Same as above with the sides swapped.
            (
                (10, Some(1), Some(10), Some(2)),
                (10, Some(1), Some(10), Some(5)),
                Some(20),
            ),
            // Negative minimum: [-5, 5] has 11 values, capped at 10 rows.
            (
                (10, Some(-5), Some(5), None),
                (10, Some(1), Some(5), None),
                Some(10),
            ),
            // All-negative ranges: [-25, -15] has 11 values, capped at 10 rows.
            (
                (10, Some(-25), Some(-20), None),
                (10, Some(-25), Some(-15), None),
                Some(10),
            ),
            // Inverted left range (min > max) gives no range estimate; the right
            // distinct count (5) is used -> 100 / 5.
            (
                (10, Some(10), Some(0), None),
                (10, Some(0), Some(10), Some(5)),
                Some(20),
            ),
            // Single-value ranges: one distinct value -> full cartesian product.
            (
                (10, Some(1), Some(1), None),
                (10, Some(1), Some(1), None),
                Some(100),
            ),
            // No statistics at all.
            ((10, None, None, None), (10, None, None, None), None),
            // Distinct counts alone are not enough: min/max bounds are required.
            ((10, None, None, Some(3)), (10, None, None, Some(3)), None),
            // Partially missing bounds.
            (
                (10, Some(2), None, Some(3)),
                (10, None, Some(5), Some(3)),
                None,
            ),
            (
                (10, None, Some(3), Some(3)),
                (10, Some(1), None, Some(3)),
                None,
            ),
            ((10, None, Some(3), None), (10, Some(1), None, None), None),
            // Disjoint ranges with inexact statistics: no estimate (rather than 0).
            (
                (10, Some(0), Some(10), None),
                (10, Some(11), Some(20), None),
                None,
            ),
            (
                (10, Some(11), Some(20), None),
                (10, Some(0), Some(10), None),
                None,
            ),
            // Disjoint (with one inverted range) even though distinct counts exist.
            (
                (10, Some(5), Some(10), Some(10)),
                (10, Some(11), Some(3), Some(10)),
                None,
            ),
            (
                (10, Some(10), Some(5), Some(10)),
                (10, Some(3), Some(7), Some(10)),
                None,
            ),
            // Zero distinct values would divide by zero: no estimate.
            (
                (10, Some(1), Some(10), Some(0)),
                (10, Some(1), Some(10), Some(0)),
                None,
            ),
        ];
        for (left_info, right_info, expected_cardinality) in cases {
            let left_num_rows = left_info.0;
            let left_col_stats =
                vec![create_column_stats(left_info.1, left_info.2, left_info.3)];
            let right_num_rows = right_info.0;
            let right_col_stats = vec![create_column_stats(
                right_info.1,
                right_info.2,
                right_info.3,
            )];
            // Direct call to the inner-join estimator.
            assert_eq!(
                estimate_inner_join_cardinality(
                    left_num_rows,
                    right_num_rows,
                    left_col_stats.clone(),
                    right_col_stats.clone(),
                    false,
                ),
                expected_cardinality
            );
            // Same estimate through `estimate_join_cardinality` for an inner join.
            // Column statistics are only reported together with a row estimate.
            let join_type = JoinType::Inner;
            let join_on = vec![(Column::new("a", 0), Column::new("b", 0))];
            let partial_join_stats = estimate_join_cardinality(
                &join_type,
                create_stats(Some(left_num_rows), Some(left_col_stats.clone()), false),
                create_stats(Some(right_num_rows), Some(right_col_stats.clone()), false),
                &join_on,
            );
            assert_eq!(
                partial_join_stats.clone().map(|s| s.num_rows),
                expected_cardinality
            );
            assert_eq!(
                partial_join_stats.map(|s| s.column_statistics),
                expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
            );
        }
        Ok(())
    }

    // With several keys, the largest per-key distinct count (here 200) is the selectivity.
    #[test]
    fn test_inner_join_cardinality_multiple_column() -> Result<()> {
        let left_col_stats = vec![
            create_column_stats(Some(0), Some(100), Some(100)),
            create_column_stats(Some(100), Some(500), Some(150)),
        ];
        let right_col_stats = vec![
            create_column_stats(Some(0), Some(100), Some(50)),
            create_column_stats(Some(100), Some(500), Some(200)),
        ];
        assert_eq!(
            estimate_inner_join_cardinality(
                400,
                400,
                left_col_stats,
                right_col_stats,
                false
            ),
            Some((400 * 400) / 200)
        );
        Ok(())
    }

    // Decimal bounds are not converted to an integer range by `get_int_range`,
    // so without distinct counts there is no estimate.
    #[test]
    fn test_inner_join_cardinality_decimal_range() -> Result<()> {
        let left_col_stats = vec![ColumnStatistics {
            distinct_count: None,
            min_value: Some(ScalarValue::Decimal128(Some(32500), 14, 4)),
            max_value: Some(ScalarValue::Decimal128(Some(35000), 14, 4)),
            ..Default::default()
        }];
        let right_col_stats = vec![ColumnStatistics {
            distinct_count: None,
            min_value: Some(ScalarValue::Decimal128(Some(33500), 14, 4)),
            max_value: Some(ScalarValue::Decimal128(Some(34000), 14, 4)),
            ..Default::default()
        }];
        assert_eq!(
            estimate_inner_join_cardinality(
                100,
                100,
                left_col_stats,
                right_col_stats,
                false
            ),
            None
        );
        Ok(())
    }

    // Keys are columns 0 and 1. Selectivity = max(100, 50, 500, 2500) = 2500,
    // so inner = 1000 * 2000 / 2500 = 800. The outer joins are:
    // - Left: max(800, 1000);
    // - Right: max(800, 2000);
    // - Full: 1000 + 2000 - 800.
    #[test]
    fn test_join_cardinality() -> Result<()> {
        let cases = vec![
            (JoinType::Inner, 800),
            (JoinType::Left, 1000),
            (JoinType::Right, 2000),
            (JoinType::Full, 2200),
        ];
        let left_col_stats = vec![
            create_column_stats(Some(0), Some(100), Some(100)),
            create_column_stats(Some(0), Some(500), Some(500)),
            create_column_stats(Some(1000), Some(10000), None),
        ];
        let right_col_stats = vec![
            create_column_stats(Some(0), Some(100), Some(50)),
            create_column_stats(Some(0), Some(2000), Some(2500)),
            create_column_stats(Some(0), Some(100), None),
        ];
        for (join_type, expected_num_rows) in cases {
            let join_on = vec![
                (Column::new("a", 0), Column::new("c", 0)),
                (Column::new("b", 1), Column::new("d", 1)),
            ];
            let partial_join_stats = estimate_join_cardinality(
                &join_type,
                create_stats(Some(1000), Some(left_col_stats.clone()), false),
                create_stats(Some(2000), Some(right_col_stats.clone()), false),
                &join_on,
            )
            .unwrap();
            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
            assert_eq!(
                partial_join_stats.column_statistics,
                [left_col_stats.clone(), right_col_stats.clone()].concat()
            );
        }
        Ok(())
    }

    // Key column 2 has disjoint ranges ([1000, 10000] vs [0, 100]) and the
    // stats are exact, so the inner estimate is 0. The outer joins keep the
    // preserved side(s): Left 1000, Right 2000, Full 1000 + 2000 - 0.
    #[test]
    fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
        let left_col_stats = vec![
            create_column_stats(Some(0), Some(100), Some(100)),
            create_column_stats(Some(0), Some(500), Some(500)),
            create_column_stats(Some(1000), Some(10000), None),
        ];
        let right_col_stats = vec![
            create_column_stats(Some(0), Some(100), Some(50)),
            create_column_stats(Some(0), Some(2000), Some(2500)),
            create_column_stats(Some(0), Some(100), None),
        ];
        let join_on = vec![
            (Column::new("a", 0), Column::new("c", 0)),
            (Column::new("x", 2), Column::new("y", 2)),
        ];
        let cases = vec![
            (JoinType::Inner, 0),
            (JoinType::Left, 1000),
            (JoinType::Right, 2000),
            (JoinType::Full, 3000),
        ];
        for (join_type, expected_num_rows) in cases {
            let partial_join_stats = estimate_join_cardinality(
                &join_type,
                create_stats(Some(1000), Some(left_col_stats.clone()), true),
                create_stats(Some(2000), Some(right_col_stats.clone()), true),
                &join_on,
            )
            .unwrap();
            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
            assert_eq!(
                partial_join_stats.column_statistics,
                [left_col_stats.clone(), right_col_stats.clone()].concat()
            );
        }
        Ok(())
    }
}