use std::any::Any;
use std::fmt::{self, Debug};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
use crate::check_if_same_properties;
use crate::common::SharedMemoryReservation;
use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
use crate::joins::stream_join_utils::{
PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
};
use crate::joins::utils::{
BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
JoinOnRef, NoopBatchTransformer, StatefulStreamResult, apply_join_filter_to_indices,
build_batch_from_indices, build_join_schema, check_join_is_valid, equal_rows_arr,
symmetric_join_output_partitioning, update_hash,
};
use crate::projection::{
ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
physical_to_column_exprs, update_join_filter, update_join_on,
};
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PlanProperties, RecordBatchStream, SendableRecordBatchStream,
joins::StreamJoinPartitionMode,
metrics::{ExecutionPlanMetricsSet, MetricsSet},
};
use arrow::array::{
ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
UInt64Array,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::bisect;
use datafusion_common::{
HashSet, JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err,
plan_err,
};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
use ahash::RandomState;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use futures::{Stream, StreamExt, ready};
use parking_lot::Mutex;
/// Shrink scale factor handed to `PruningJoinHashMap::prune_hash_values` when
/// pruning build-side state (see `OneSideHashJoiner::prune_internal_state`).
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
/// Join plan node that builds a hash table on *both* inputs, letting each
/// incoming batch from one side probe the other side's table. Combined with
/// optional per-side sort expressions and a filter interval graph, buffered
/// state can be pruned incrementally, which makes this operator usable with
/// streaming (potentially unbounded) inputs.
#[derive(Debug, Clone)]
pub struct SymmetricHashJoinExec {
    /// Left input plan.
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right input plan.
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Equi-join key pairs as (left expression, right expression).
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Optional non-equi filter applied to candidate matches.
    pub(crate) filter: Option<JoinFilter>,
    /// The type of join (inner, left, semi, anti, mark, full, ...).
    pub(crate) join_type: JoinType,
    // Seeded deterministically in `try_new` so both sides hash identically.
    random_state: RandomState,
    /// Execution metrics for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Maps each output column back to its originating side and index.
    column_indices: Vec<ColumnIndex>,
    /// Whether NULL equi-join keys compare equal.
    pub(crate) null_equality: NullEquality,
    /// Optional sort ordering of the left input, enabling state pruning.
    pub(crate) left_sort_exprs: Option<LexOrdering>,
    /// Optional sort ordering of the right input, enabling state pruning.
    pub(crate) right_sort_exprs: Option<LexOrdering>,
    /// Partitioning mode required of the children.
    mode: StreamJoinPartitionMode,
    /// Cached plan properties (schema, partitioning, equivalences, ...).
    cache: Arc<PlanProperties>,
}
impl SymmetricHashJoinExec {
    /// Creates a new `SymmetricHashJoinExec`.
    ///
    /// # Errors
    /// Returns a plan error when `on` is empty, or when the equi-join keys
    /// are invalid for the child schemas (see `check_join_is_valid`).
    #[expect(clippy::too_many_arguments)]
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        null_equality: NullEquality,
        left_sort_exprs: Option<LexOrdering>,
        right_sort_exprs: Option<LexOrdering>,
        mode: StreamJoinPartitionMode,
    ) -> Result<Self> {
        let left_schema = left.schema();
        let right_schema = right.schema();
        // A symmetric hash join needs at least one equi-join key to probe on.
        if on.is_empty() {
            return plan_err!(
                "On constraints in SymmetricHashJoinExec should be non-empty"
            );
        }
        check_join_is_valid(&left_schema, &right_schema, &on)?;
        let (schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, join_type);
        // Fixed seeds keep hashing deterministic across both sides and runs.
        let random_state = RandomState::with_seeds(0, 0, 0, 0);
        let schema = Arc::new(schema);
        let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?;
        Ok(SymmetricHashJoinExec {
            left,
            right,
            on,
            filter,
            join_type: *join_type,
            random_state,
            metrics: ExecutionPlanMetricsSet::new(),
            column_indices,
            null_equality,
            left_sort_exprs,
            right_sort_exprs,
            mode,
            cache: Arc::new(cache),
        })
    }

    /// Computes the cached plan properties: equivalence properties, output
    /// partitioning, emission type and boundedness, all derived from the
    /// children and the join type.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        join_type: JoinType,
        join_on: JoinOnRef,
    ) -> Result<PlanProperties> {
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            schema,
            // NOTE(review): presumably the "maintains input order" flags for
            // each side; both false here — confirm against
            // `join_equivalence_properties`.
            &[false, false],
            None,
            join_on,
        )?;
        let output_partitioning =
            symmetric_join_output_partitioning(left, right, &join_type)?;
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children([left, right]),
            boundedness_from_children([left, right]),
        ))
    }

    /// Left input plan.
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// Right input plan.
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// Equi-join key pairs.
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }

    /// Optional non-equi join filter.
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }

    /// The join type.
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// NULL equality behavior for the join keys.
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }

    /// Partitioning mode required of the children.
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
        self.mode
    }

    /// Sort expressions of the left input, if any.
    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
        self.left_sort_exprs.as_ref()
    }

    /// Sort expressions of the right input, if any.
    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
        self.right_sort_exprs.as_ref()
    }

    /// Returns `true` when a join filter exists and the *first* output
    /// ordering expression of each child can be converted into the filter's
    /// intermediate schema — i.e. when interval-based state pruning is
    /// possible.
    pub fn check_if_order_information_available(&self) -> Result<bool> {
        if let Some(filter) = self.filter() {
            let left = self.left();
            if let Some(left_ordering) = left.output_ordering() {
                let right = self.right();
                if let Some(right_ordering) = right.output_ordering() {
                    let left_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Left,
                        filter,
                        &left.schema(),
                        &left_ordering[0],
                    )?
                    .is_some();
                    let right_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Right,
                        filter,
                        &right.schema(),
                        &right_ordering[0],
                    )?
                    .is_some();
                    return Ok(left_convertible && right_convertible);
                }
            }
        }
        Ok(false)
    }

    /// Rebuilds `self` with new children while reusing all other fields
    /// (including the cached properties); metrics are reset. Callers must
    /// ensure the new children have equivalent properties.
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        let left = children.swap_remove(0);
        // After removing index 0, the former index 1 now sits at index 0.
        let right = children.swap_remove(0);
        Self {
            left,
            right,
            metrics: ExecutionPlanMetricsSet::new(),
            ..Self::clone(self)
        }
    }
}
impl DisplayAs for SymmetricHashJoinExec {
    /// Renders this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                // Filter clause is omitted entirely when no filter is set.
                let display_filter = self.filter.as_ref().map_or_else(
                    || "".to_string(),
                    |f| format!(", filter={}", f.expression()),
                );
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| format!("({c1}, {c2})"))
                    .collect::<Vec<String>>()
                    .join(", ");
                write!(
                    f,
                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
                    self.mode, self.join_type, on, display_filter
                )
            }
            DisplayFormatType::TreeRender => {
                // Tree rendering uses SQL-like syntax for the join keys.
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| {
                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
                    })
                    .collect::<Vec<String>>()
                    .join(", ");
                writeln!(f, "mode={:?}", self.mode)?;
                // Inner joins are the default; only show non-default types.
                if *self.join_type() != JoinType::Inner {
                    writeln!(f, "join_type={:?}", self.join_type)?;
                }
                writeln!(f, "on={on}")
            }
        }
    }
}
impl ExecutionPlan for SymmetricHashJoinExec {
    fn name(&self) -> &'static str {
        "SymmetricHashJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    /// Requires hash partitioning on the join keys in `Partitioned` mode,
    /// or a single partition per child in `SinglePartition` mode.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match self.mode {
            StreamJoinPartitionMode::Partitioned => {
                let (left_expr, right_expr) = self
                    .on
                    .iter()
                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
                    .unzip();
                vec![
                    Distribution::HashPartitioned(left_expr),
                    Distribution::HashPartitioned(right_expr),
                ]
            }
            StreamJoinPartitionMode::SinglePartition => {
                vec![Distribution::SinglePartition, Distribution::SinglePartition]
            }
        }
    }

    /// Propagates the configured per-side sort expressions (if any) as input
    /// ordering requirements, so state pruning assumptions hold at runtime.
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![
            self.left_sort_exprs
                .as_ref()
                .map(|e| OrderingRequirements::from(e.clone())),
            self.right_sort_exprs
                .as_ref()
                .map(|e| OrderingRequirements::from(e.clone())),
        ]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Fast path: reuse cached properties when the new children's
        // properties match the old ones.
        check_if_same_properties!(self, children);
        Ok(Arc::new(SymmetricHashJoinExec::try_new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.on.clone(),
            self.filter.clone(),
            &self.join_type,
            self.null_equality,
            self.left_sort_exprs.clone(),
            self.right_sort_exprs.clone(),
            self.mode,
        )?))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Builds the symmetric join stream for `partition`.
    ///
    /// # Errors
    /// Returns an internal error when the children's partition counts differ,
    /// and propagates errors from sorted-filter-expression preparation, child
    /// execution, and memory reservation.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let left_partitions = self.left.output_partitioning().partition_count();
        let right_partitions = self.right.output_partitioning().partition_count();
        assert_eq_or_internal_err!(
            left_partitions,
            right_partitions,
            "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
consider using RepartitionExec"
        );
        // Interval-based pruning is only possible when both sides are sorted
        // and there is a join filter to derive intervals from.
        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
            self.left_sort_exprs(),
            self.right_sort_exprs(),
            &self.filter,
        ) {
            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
                let (left, right, graph) = prepare_sorted_exprs(
                    filter,
                    &self.left,
                    &self.right,
                    left_sort_exprs,
                    right_sort_exprs,
                )?;
                (Some(left), Some(right), Some(graph))
            }
            _ => (None, None, None),
        };
        let (on_left, on_right) = self.on.iter().cloned().unzip();
        let left_side_joiner =
            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
        let right_side_joiner =
            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
        let right_stream = self.right.execute(partition, Arc::clone(&context))?;
        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();
        let reservation = Arc::new(Mutex::new(
            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
                .register(context.memory_pool()),
        ));
        // Account for the interval graph's memory up front.
        if let Some(g) = graph.as_ref() {
            reservation.lock().try_grow(g.size())?;
        }
        // The two branches differ only in the batch transformer type
        // parameter: splitting output batches vs. passing them through.
        if enforce_batch_size_in_joins {
            Ok(Box::pin(SymmetricHashJoinStream {
                left_stream,
                right_stream,
                schema: self.schema(),
                filter: self.filter.clone(),
                join_type: self.join_type,
                random_state: self.random_state.clone(),
                left: left_side_joiner,
                right: right_side_joiner,
                column_indices: self.column_indices.clone(),
                metrics: StreamJoinMetrics::new(partition, &self.metrics),
                graph,
                left_sorted_filter_expr,
                right_sorted_filter_expr,
                null_equality: self.null_equality,
                state: SHJStreamState::PullRight,
                reservation,
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(SymmetricHashJoinStream {
                left_stream,
                right_stream,
                schema: self.schema(),
                filter: self.filter.clone(),
                join_type: self.join_type,
                random_state: self.random_state.clone(),
                left: left_side_joiner,
                right: right_side_joiner,
                column_indices: self.column_indices.clone(),
                metrics: StreamJoinMetrics::new(partition, &self.metrics),
                graph,
                left_sorted_filter_expr,
                right_sorted_filter_expr,
                null_equality: self.null_equality,
                state: SHJStreamState::PullRight,
                reservation,
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    /// Tries to push `projection` below this join, returning the rewritten
    /// plan or `Ok(None)` when pushdown is not possible.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Pushdown only applies when every projection expression is a bare
        // column reference.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };
        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );
        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }
        let Some(new_on) = update_join_on(
            &projection_as_columns[0..=far_right_left_col_ind as _],
            &projection_as_columns[far_left_right_col_ind as _..],
            self.on(),
            self.left().schema().fields().len(),
        ) else {
            return Ok(None);
        };
        let new_filter = if let Some(filter) = self.filter() {
            match update_join_filter(
                &projection_as_columns[0..=far_right_left_col_ind as _],
                &projection_as_columns[far_left_right_col_ind as _..],
                filter,
                self.left().schema().fields().len(),
            ) {
                Some(updated_filter) => Some(updated_filter),
                None => return Ok(None),
            }
        } else {
            None
        };
        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;
        SymmetricHashJoinExec::try_new(
            Arc::new(new_left),
            Arc::new(new_right),
            new_on,
            new_filter,
            self.join_type(),
            self.null_equality(),
            // Fix: the children keep their original left/right roles here, so
            // the left sort expressions must come from the left child and the
            // right ones from the right child. These two arguments were
            // previously swapped.
            self.left().output_ordering().cloned(),
            self.right().output_ordering().cloned(),
            self.partition_mode(),
        )
        .map(|e| Some(Arc::new(e) as _))
    }
}
/// Streaming state for a running symmetric hash join partition. `T` is the
/// batch transformer that optionally splits output batches to the session's
/// configured batch size.
struct SymmetricHashJoinStream<T> {
    /// Left input stream.
    left_stream: SendableRecordBatchStream,
    /// Right input stream.
    right_stream: SendableRecordBatchStream,
    /// Output schema of the join.
    schema: Arc<Schema>,
    /// Optional non-equi join filter.
    filter: Option<JoinFilter>,
    /// The type of join.
    join_type: JoinType,
    /// Buffered state and hash table for the left side.
    left: OneSideHashJoiner,
    /// Buffered state and hash table for the right side.
    right: OneSideHashJoiner,
    /// Maps output columns back to the originating side/index.
    column_indices: Vec<ColumnIndex>,
    /// Interval graph over the filter expression, used for pruning.
    graph: Option<ExprIntervalGraph>,
    /// Sorted filter expression tracking the left side's interval.
    left_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Sorted filter expression tracking the right side's interval.
    right_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Hash seed shared by both sides.
    random_state: RandomState,
    /// Whether NULL join keys compare equal.
    null_equality: NullEquality,
    /// Per-partition join metrics.
    metrics: StreamJoinMetrics,
    /// Memory reservation resized to this stream's current state size.
    reservation: SharedMemoryReservation,
    /// Current position in the pull/drain/finalize state machine.
    state: SHJStreamState,
    /// Transformer applied to output batches before they are emitted.
    batch_transformer: T,
}
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
    for SymmetricHashJoinStream<T>
{
    /// Output schema of the join stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
    type Item = Result<RecordBatch>;

    /// Delegates to `poll_next_impl`; the stream is `Unpin`, so a plain
    /// `&mut self` receiver suffices inside.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}
/// Determines how many leading rows of the buffered build side can be pruned.
///
/// Evaluates the build side's original sorted expression over `buffer` and
/// binary-searches for the bound of the expression's current interval: the
/// upper bound when the sort is descending, the lower bound otherwise. The
/// returned index is the number of rows that can no longer match.
fn determine_prune_length(
    buffer: &RecordBatch,
    build_side_filter_expr: &SortedFilterExpr,
) -> Result<usize> {
    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
    let interval = build_side_filter_expr.interval();
    // Evaluate the sorted expression against all buffered build-side rows.
    let batch_arr = origin_sorted_expr
        .expr
        .evaluate(buffer)?
        .into_array(buffer.num_rows())?;
    // Pick the interval bound matching the sort direction.
    let target = if origin_sorted_expr.options.descending {
        interval.upper().clone()
    } else {
        interval.lower().clone()
    };
    // Binary search for the prune boundary using the same sort options.
    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
}
/// Returns `true` when the given build side must emit rows for unmatched or
/// marked build-side entries (during pruning or finalization) for this join
/// type — i.e. outer, semi, anti, and mark joins on the matching side, plus
/// full joins on either side.
fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
    // Join types for which the given build side owes final output.
    let producing_types = if build_side == JoinSide::Left {
        [
            JoinType::Left,
            JoinType::LeftAnti,
            JoinType::Full,
            JoinType::LeftSemi,
            JoinType::LeftMark,
        ]
    } else {
        [
            JoinType::Right,
            JoinType::RightAnti,
            JoinType::Full,
            JoinType::RightSemi,
            JoinType::RightMark,
        ]
    };
    producing_types.contains(&join_type)
}
/// Computes the (build, probe) index arrays to emit for the prunable section
/// of the build side, according to the join type.
///
/// * Mark joins: one entry per build row, with the probe index `Some(0)` iff
///   the row was previously matched.
/// * Outer/anti/full joins: indices of *unvisited* build rows, with all-null
///   probe indices.
/// * Semi joins: indices of *visited* build rows, with all-null probe indices.
///
/// `deleted_offset` translates buffer positions into absolute row ids as
/// recorded in `visited_rows`.
///
/// # Panics
/// Reaches `unreachable!` for (side, join type) combinations that never
/// produce build-side results; callers gate on
/// `need_to_produce_result_in_final` first.
fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
    build_side: JoinSide,
    prune_length: usize,
    visited_rows: &HashSet<usize>,
    deleted_offset: usize,
    join_type: JoinType,
) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
where
    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
{
    let result = match (build_side, join_type) {
        // The LeftMark and RightMark arms were identical duplicates; merged
        // into a single `|` pattern. Every build row is emitted, with the
        // probe index acting as the "matched" marker.
        (JoinSide::Left, JoinType::LeftMark)
        | (JoinSide::Right, JoinType::RightMark) => {
            let build_indices = (0..prune_length)
                .map(L::Native::from_usize)
                .collect::<PrimitiveArray<L>>();
            let probe_indices = (0..prune_length)
                .map(|idx| {
                    // Mark rows that have been matched at least once.
                    visited_rows
                        .contains(&(idx + deleted_offset))
                        .then_some(R::Native::from_usize(0).unwrap())
                })
                .collect();
            (build_indices, probe_indices)
        }
        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
        | (_, JoinType::Full) => {
            // Emit build rows that never matched, paired with null probe
            // indices.
            let build_unmatched_indices =
                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
            let mut builder =
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
            builder.append_nulls(build_unmatched_indices.len());
            let probe_indices = builder.finish();
            (build_unmatched_indices, probe_indices)
        }
        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
            // Emit build rows that matched at least once, paired with null
            // probe indices.
            let build_unmatched_indices =
                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
            let mut builder =
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
            builder.append_nulls(build_unmatched_indices.len());
            let probe_indices = builder.finish();
            (build_unmatched_indices, probe_indices)
        }
        _ => unreachable!(),
    };
    Ok(result)
}
/// Produces the output batch owed by the build side alone (unmatched/marked
/// rows) for the first `prune_length` buffered rows.
///
/// Returns `Ok(None)` when the join type requires no build-side output, when
/// `prune_length` is zero, or when the resulting batch is empty. The probe
/// side is represented by an empty batch since these rows have no match.
pub(crate) fn build_side_determined_results(
    build_hash_joiner: &OneSideHashJoiner,
    output_schema: &SchemaRef,
    prune_length: usize,
    probe_schema: SchemaRef,
    join_type: JoinType,
    column_indices: &[ColumnIndex],
) -> Result<Option<RecordBatch>> {
    if prune_length > 0
        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
    {
        let (build_indices, probe_indices) = calculate_indices_by_join_type(
            build_hash_joiner.build_side,
            prune_length,
            &build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            join_type,
        )?;
        // Probe columns are filled from an empty batch (i.e. with nulls).
        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
        build_batch_from_indices(
            output_schema.as_ref(),
            &build_hash_joiner.input_buffer,
            &empty_probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
            join_type,
        )
        // Normalize empty batches to `None` so callers can skip them.
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    } else {
        Ok(None)
    }
}
/// Joins `probe_batch` against the buffered build side.
///
/// Looks up matching row pairs in the build-side hash map, optionally applies
/// the join filter, records visited rows on whichever sides must later emit
/// unmatched results, and builds the equi-join output batch. Returns
/// `Ok(None)` when either side is empty, when the output batch is empty, or
/// for semi/anti/mark joins — those only update visited-row bookkeeping here
/// and produce their output during pruning/finalization.
#[expect(clippy::too_many_arguments)]
pub(crate) fn join_with_probe_batch(
    build_hash_joiner: &mut OneSideHashJoiner,
    probe_hash_joiner: &mut OneSideHashJoiner,
    schema: &SchemaRef,
    join_type: JoinType,
    filter: Option<&JoinFilter>,
    probe_batch: &RecordBatch,
    column_indices: &[ColumnIndex],
    random_state: &RandomState,
    null_equality: NullEquality,
) -> Result<Option<RecordBatch>> {
    // Nothing to join when either side has no rows.
    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
        return Ok(None);
    }
    let (build_indices, probe_indices) = lookup_join_hashmap(
        &build_hash_joiner.hashmap,
        &build_hash_joiner.input_buffer,
        probe_batch,
        &build_hash_joiner.on,
        &probe_hash_joiner.on,
        random_state,
        null_equality,
        &mut build_hash_joiner.hashes_buffer,
        Some(build_hash_joiner.deleted_offset),
    )?;
    // Narrow the candidate pairs with the non-equi filter, if present.
    let (build_indices, probe_indices) = if let Some(filter) = filter {
        apply_join_filter_to_indices(
            &build_hash_joiner.input_buffer,
            probe_batch,
            build_indices,
            probe_indices,
            filter,
            build_hash_joiner.build_side,
            None,
            join_type,
        )?
    } else {
        (build_indices, probe_indices)
    };
    // Track matched build rows if this side owes final (unmatched) output.
    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
        record_visited_indices(
            &mut build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            &build_indices,
        );
    }
    // Likewise for the probe side (the opposite join side).
    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
        record_visited_indices(
            &mut probe_hash_joiner.visited_rows,
            probe_hash_joiner.offset,
            &probe_indices,
        );
    }
    if matches!(
        join_type,
        JoinType::LeftAnti
            | JoinType::RightAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::RightSemi
            | JoinType::RightMark
    ) {
        // These join types emit nothing during probing; their output comes
        // from the visited-row bookkeeping above.
        Ok(None)
    } else {
        build_batch_from_indices(
            schema,
            &build_hash_joiner.input_buffer,
            probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
            join_type,
        )
        // Normalize empty batches to `None`.
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    }
}
/// Finds candidate (build, probe) row-index pairs by hashing the probe batch's
/// join keys and looking them up in the build-side hash map, then verifies key
/// equality to discard hash collisions.
///
/// `deleted_offset` lets the lookup skip build rows that have already been
/// pruned. `hashes_buffer` is reused across calls to avoid reallocation.
#[expect(clippy::too_many_arguments)]
fn lookup_join_hashmap(
    build_hashmap: &PruningJoinHashMap,
    build_batch: &RecordBatch,
    probe_batch: &RecordBatch,
    build_on: &[PhysicalExprRef],
    probe_on: &[PhysicalExprRef],
    random_state: &RandomState,
    null_equality: NullEquality,
    hashes_buffer: &mut Vec<u64>,
    deleted_offset: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
    let keys_values = evaluate_expressions_to_arrays(probe_on, probe_batch)?;
    let build_join_values = evaluate_expressions_to_arrays(build_on, build_batch)?;
    hashes_buffer.clear();
    hashes_buffer.resize(probe_batch.num_rows(), 0);
    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
    // The map yields matches in reverse iteration order, so feed it the
    // hashes reversed and flip the results back to ascending probe order.
    let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices(
        Box::new(hash_values.iter().enumerate().rev()),
        deleted_offset,
    );
    matched_probe.reverse();
    matched_build.reverse();
    let build_indices: UInt64Array = matched_build.into();
    let probe_indices: UInt32Array = matched_probe.into();
    // Drop hash-collision pairs whose actual key values differ.
    let (build_indices, probe_indices) = equal_rows_arr(
        &build_indices,
        &probe_indices,
        &build_join_values,
        &keys_values,
        null_equality,
    )?;
    Ok((build_indices, probe_indices))
}
/// Per-side join state: the buffered input, its hash map, and the bookkeeping
/// needed to prune rows and to emit unmatched results later.
pub struct OneSideHashJoiner {
    /// Which join side this state belongs to.
    build_side: JoinSide,
    /// All not-yet-pruned input rows, concatenated into one batch.
    pub input_buffer: RecordBatch,
    /// Join key expressions for this side.
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hash map from key hashes to absolute row ids in this side's input.
    pub(crate) hashmap: PruningJoinHashMap,
    /// Scratch buffer reused when hashing incoming batches.
    pub(crate) hashes_buffer: Vec<u64>,
    /// Absolute row ids that have matched at least once.
    pub(crate) visited_rows: HashSet<usize>,
    /// Total number of rows ever received (absolute id of the next row).
    pub(crate) offset: usize,
    /// Number of rows pruned from the front of the buffer so far.
    pub(crate) deleted_offset: usize,
}
impl OneSideHashJoiner {
    /// Approximate in-memory size of this side's state, in bytes, used for
    /// memory-pool accounting.
    pub fn size(&self) -> usize {
        let mut size = 0;
        size += size_of_val(self);
        size += size_of_val(&self.build_side);
        size += self.input_buffer.get_array_memory_size();
        size += size_of_val(&self.on);
        size += self.hashmap.size();
        size += self.hashes_buffer.capacity() * size_of::<u64>();
        size += self.visited_rows.capacity() * size_of::<usize>();
        size += size_of_val(&self.offset);
        size += size_of_val(&self.deleted_offset);
        size
    }

    /// Creates empty state for one join side.
    pub fn new(
        build_side: JoinSide,
        on: Vec<PhysicalExprRef>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            build_side,
            input_buffer: RecordBatch::new_empty(schema),
            on,
            hashmap: PruningJoinHashMap::with_capacity(0),
            hashes_buffer: vec![],
            visited_rows: HashSet::new(),
            offset: 0,
            deleted_offset: 0,
        }
    }

    /// Appends `batch` to the input buffer and inserts its join-key hashes
    /// into the hash map. The caller advances `self.offset` afterwards.
    pub(crate) fn update_internal_state(
        &mut self,
        batch: &RecordBatch,
        random_state: &RandomState,
    ) -> Result<()> {
        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
        // Size the scratch buffer for the incoming batch's rows.
        self.hashes_buffer.resize(batch.num_rows(), 0);
        update_hash(
            &self.on,
            batch,
            &mut self.hashmap,
            self.offset,
            random_state,
            &mut self.hashes_buffer,
            self.deleted_offset,
            false,
        )?;
        Ok(())
    }

    /// Recomputes the filter-expression intervals via the interval graph and
    /// returns how many leading buffered rows can be pruned (0 when the
    /// buffer is empty or the build-side interval did not shrink).
    pub(crate) fn calculate_prune_length_with_probe_batch(
        &mut self,
        build_side_sorted_filter_expr: &mut SortedFilterExpr,
        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
        graph: &mut ExprIntervalGraph,
    ) -> Result<usize> {
        if self.input_buffer.num_rows() == 0 {
            return Ok(0);
        }
        // Collect (node, interval) pairs for both sides; build side first,
        // so index 0 below is the build side's updated interval.
        let mut filter_intervals = vec![];
        for expr in [
            &build_side_sorted_filter_expr,
            &probe_side_sorted_filter_expr,
        ] {
            filter_intervals.push((expr.node_index(), expr.interval().clone()))
        }
        // Constrain the graph with "filter evaluates to true" and propagate.
        graph.update_ranges(&mut filter_intervals, Interval::TRUE)?;
        let calculated_build_side_interval = filter_intervals.remove(0).1;
        // No change in the interval means nothing new can be pruned.
        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
            return Ok(0);
        }
        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);
        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
    }

    /// Removes the first `prune_length` rows from the hash map, visited-row
    /// set, and input buffer, and advances `deleted_offset` accordingly.
    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
        self.hashmap.prune_hash_values(
            prune_length,
            self.deleted_offset as u64,
            HASHMAP_SHRINK_SCALE_FACTOR,
        );
        // Pruned rows can no longer produce final output; forget them.
        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
            self.visited_rows.remove(&row);
        }
        self.input_buffer = self
            .input_buffer
            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
        self.deleted_offset += prune_length;
        Ok(())
    }
}
impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
    /// Drives the join state machine until an output batch is ready, the
    /// stream ends, or an input is pending.
    ///
    /// Batches produced by the state handlers are routed through the batch
    /// transformer; pending transformer output is drained before new input
    /// is pulled.
    fn poll_next_impl(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            match self.batch_transformer.next() {
                // No pending output: advance the state machine.
                None => {
                    let result = match self.state() {
                        SHJStreamState::PullRight => {
                            ready!(self.fetch_next_from_right_stream(cx))
                        }
                        SHJStreamState::PullLeft => {
                            ready!(self.fetch_next_from_left_stream(cx))
                        }
                        SHJStreamState::RightExhausted => {
                            ready!(self.handle_right_stream_end(cx))
                        }
                        SHJStreamState::LeftExhausted => {
                            ready!(self.handle_left_stream_end(cx))
                        }
                        SHJStreamState::BothExhausted {
                            final_result: false,
                        } => self.prepare_for_final_results_after_exhaustion(),
                        SHJStreamState::BothExhausted { final_result: true } => {
                            return Poll::Ready(None);
                        }
                    };
                    match result? {
                        StatefulStreamResult::Ready(None) => {
                            return Poll::Ready(None);
                        }
                        StatefulStreamResult::Ready(Some(batch)) => {
                            // Hand the batch to the transformer; it is
                            // emitted (possibly split) on later iterations.
                            self.batch_transformer.set_batch(batch);
                        }
                        // Continue: loop again without emitting.
                        _ => {}
                    }
                }
                // Transformer has output ready: record metrics and emit it.
                Some((batch, _)) => {
                    return self
                        .metrics
                        .baseline_metrics
                        .record_poll(Poll::Ready(Some(Ok(batch))));
                }
            }
        }
    }

    /// Pulls one batch from the right input; on data, processes it and
    /// switches to pulling from the left; on end-of-stream, moves to
    /// `RightExhausted`.
    fn fetch_next_from_right_stream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        match ready!(self.right_stream().poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                // Skip empty batches without changing state.
                if batch.num_rows() == 0 {
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
                }
                self.set_state(SHJStreamState::PullLeft);
                Poll::Ready(self.process_batch_from_right(&batch))
            }
            Some(Err(e)) => Poll::Ready(Err(e)),
            None => {
                self.set_state(SHJStreamState::RightExhausted);
                Poll::Ready(Ok(StatefulStreamResult::Continue))
            }
        }
    }

    /// Mirror of `fetch_next_from_right_stream` for the left input.
    fn fetch_next_from_left_stream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        match ready!(self.left_stream().poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() == 0 {
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
                }
                self.set_state(SHJStreamState::PullRight);
                Poll::Ready(self.process_batch_from_left(&batch))
            }
            Some(Err(e)) => Poll::Ready(Err(e)),
            None => {
                self.set_state(SHJStreamState::LeftExhausted);
                Poll::Ready(Ok(StatefulStreamResult::Continue))
            }
        }
    }

    /// With the right input finished, keeps draining the left input; when it
    /// too finishes, moves to `BothExhausted { final_result: false }`.
    fn handle_right_stream_end(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        match ready!(self.left_stream().poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() == 0 {
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
                }
                Poll::Ready(self.process_batch_after_right_end(&batch))
            }
            Some(Err(e)) => Poll::Ready(Err(e)),
            None => {
                self.set_state(SHJStreamState::BothExhausted {
                    final_result: false,
                });
                Poll::Ready(Ok(StatefulStreamResult::Continue))
            }
        }
    }

    /// Mirror of `handle_right_stream_end` for the case where the left input
    /// finished first.
    fn handle_left_stream_end(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        match ready!(self.right_stream().poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() == 0 {
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
                }
                Poll::Ready(self.process_batch_after_left_end(&batch))
            }
            Some(Err(e)) => Poll::Ready(Err(e)),
            None => {
                self.set_state(SHJStreamState::BothExhausted {
                    final_result: false,
                });
                Poll::Ready(Ok(StatefulStreamResult::Continue))
            }
        }
    }

    /// Transitions to the terminal state and emits the final build-side
    /// results for both sides.
    fn prepare_for_final_results_after_exhaustion(
        &mut self,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.set_state(SHJStreamState::BothExhausted { final_result: true });
        self.process_batches_before_finalization()
    }

    /// Processes a right-side batch (right = probe side).
    fn process_batch_from_right(
        &mut self,
        batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.perform_join_for_given_side(batch, JoinSide::Right)
            .map(|maybe_batch| {
                if maybe_batch.is_some() {
                    StatefulStreamResult::Ready(maybe_batch)
                } else {
                    StatefulStreamResult::Continue
                }
            })
    }

    /// Processes a left-side batch (left = probe side).
    fn process_batch_from_left(
        &mut self,
        batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.perform_join_for_given_side(batch, JoinSide::Left)
            .map(|maybe_batch| {
                if maybe_batch.is_some() {
                    StatefulStreamResult::Ready(maybe_batch)
                } else {
                    StatefulStreamResult::Continue
                }
            })
    }

    /// After the left input ends, right-side batches are processed normally.
    fn process_batch_after_left_end(
        &mut self,
        right_batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_right(right_batch)
    }

    /// After the right input ends, left-side batches are processed normally.
    fn process_batch_after_right_end(
        &mut self,
        left_batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_left(left_batch)
    }

    /// Builds the final unmatched/marked results from both sides' remaining
    /// buffered rows and combines them into a single batch.
    fn process_batches_before_finalization(
        &mut self,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let left_result = build_side_determined_results(
            &self.left,
            &self.schema,
            self.left.input_buffer.num_rows(),
            self.right.input_buffer.schema(),
            self.join_type,
            &self.column_indices,
        )?;
        let right_result = build_side_determined_results(
            &self.right,
            &self.schema,
            self.right.input_buffer.num_rows(),
            self.left.input_buffer.schema(),
            self.join_type,
            &self.column_indices,
        )?;
        let result = combine_two_batches(&self.schema, left_result, right_result)?;
        if result.is_some() {
            return Ok(StatefulStreamResult::Ready(result));
        }
        Ok(StatefulStreamResult::Continue)
    }

    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.right_stream
    }

    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.left_stream
    }

    fn set_state(&mut self, state: SHJStreamState) {
        self.state = state;
    }

    /// Returns a clone of the current state (cheap: small enum).
    fn state(&mut self) -> SHJStreamState {
        self.state.clone()
    }

    /// Approximate memory footprint of the stream's state, in bytes, used to
    /// resize the memory reservation after each probe.
    fn size(&self) -> usize {
        let mut size = 0;
        size += size_of_val(&self.schema);
        size += size_of_val(&self.filter);
        size += size_of_val(&self.join_type);
        size += self.left.size();
        size += self.right.size();
        size += size_of_val(&self.column_indices);
        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
        size += size_of_val(&self.left_sorted_filter_expr);
        size += size_of_val(&self.right_sorted_filter_expr);
        size += size_of_val(&self.random_state);
        size += size_of_val(&self.null_equality);
        size += size_of_val(&self.metrics);
        size
    }

    /// Core per-batch join step: inserts `probe_batch` into its own side,
    /// probes the opposite side's hash table, optionally prunes the build
    /// side via interval analysis, and combines equi-join output with any
    /// pruning-time (anti/outer/mark) output.
    fn perform_join_for_given_side(
        &mut self,
        probe_batch: &RecordBatch,
        probe_side: JoinSide,
    ) -> Result<Option<RecordBatch>> {
        // Select per-side state: the probing side's joiner, the opposite
        // (build) side's joiner, and the matching filter exprs/metrics.
        let (
            probe_hash_joiner,
            build_hash_joiner,
            probe_side_sorted_filter_expr,
            build_side_sorted_filter_expr,
            probe_side_metrics,
        ) = if probe_side.eq(&JoinSide::Left) {
            (
                &mut self.left,
                &mut self.right,
                &mut self.left_sorted_filter_expr,
                &mut self.right_sorted_filter_expr,
                &mut self.metrics.left,
            )
        } else {
            (
                &mut self.right,
                &mut self.left,
                &mut self.right_sorted_filter_expr,
                &mut self.left_sorted_filter_expr,
                &mut self.metrics.right,
            )
        };
        probe_side_metrics.input_batches.add(1);
        probe_side_metrics.input_rows.add(probe_batch.num_rows());
        // The probe side also buffers the batch: it is the build side for
        // batches arriving from the other input later.
        probe_hash_joiner.update_internal_state(probe_batch, &self.random_state)?;
        let equal_result = join_with_probe_batch(
            build_hash_joiner,
            probe_hash_joiner,
            &self.schema,
            self.join_type,
            self.filter.as_ref(),
            probe_batch,
            &self.column_indices,
            &self.random_state,
            self.null_equality,
        )?;
        probe_hash_joiner.offset += probe_batch.num_rows();
        // Pruning path: only available when both sorted filter exprs and the
        // interval graph were prepared at execute() time.
        let anti_result = if let (
            Some(build_side_sorted_filter_expr),
            Some(probe_side_sorted_filter_expr),
            Some(graph),
        ) = (
            build_side_sorted_filter_expr.as_mut(),
            probe_side_sorted_filter_expr.as_mut(),
            self.graph.as_mut(),
        ) {
            calculate_filter_expr_intervals(
                &build_hash_joiner.input_buffer,
                build_side_sorted_filter_expr,
                probe_batch,
                probe_side_sorted_filter_expr,
            )?;
            let prune_length = build_hash_joiner
                .calculate_prune_length_with_probe_batch(
                    build_side_sorted_filter_expr,
                    probe_side_sorted_filter_expr,
                    graph,
                )?;
            // Emit results owed by soon-to-be-pruned rows *before* pruning.
            let result = build_side_determined_results(
                build_hash_joiner,
                &self.schema,
                prune_length,
                probe_batch.schema(),
                self.join_type,
                &self.column_indices,
            )?;
            build_hash_joiner.prune_internal_state(prune_length)?;
            result
        } else {
            None
        };
        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
        // Keep the memory reservation in sync with the current state size.
        let capacity = self.size();
        self.metrics.stream_memory_usage.set(capacity);
        self.reservation.lock().try_resize(capacity)?;
        Ok(result)
    }
}
/// State machine for `SymmetricHashJoinStream`: alternate pulling from the
/// right and left inputs, then drain whichever side is still producing, and
/// finally emit build-side results once both inputs are exhausted.
#[derive(Clone, Debug)]
pub enum SHJStreamState {
    /// Poll the right input next.
    PullRight,
    /// Poll the left input next.
    PullLeft,
    /// Right input finished; keep draining the left input.
    RightExhausted,
    /// Left input finished; keep draining the right input.
    LeftExhausted,
    /// Both inputs finished; `final_result` records whether the final
    /// build-side results have already been emitted.
    BothExhausted { final_result: bool },
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
use super::*;
use crate::joins::test_utils::{
build_sides_record_batches, compare_batches, complicated_filter,
create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
partitioned_sym_join_with_filter, split_record_batches,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion_common::ScalarValue;
use datafusion_execution::config::SessionConfig;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{Column, binary, col, lit};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use rstest::*;
/// Row count of the generated test tables.
const TABLE_SIZE: i32 = 30;
/// Cache key: (left cardinality, right cardinality, batch size).
type TableKey = (i32, i32, usize);
/// Cached value: partitioned (left, right) record batches.
type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>);
/// Process-wide cache so repeated test cases reuse generated tables.
static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
fn get_or_create_table(
cardinality: (i32, i32),
batch_size: usize,
) -> Result<TableValue> {
{
let cache = TABLE_CACHE.lock().unwrap();
if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
return Ok(table.clone());
}
}
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let (left_partition, right_partition) = (
split_record_batches(&left_batch, batch_size)?,
split_record_batches(&right_batch, batch_size)?,
);
let mut cache = TABLE_CACHE.lock().unwrap();
cache.insert(
(cardinality.0, cardinality.1, batch_size),
(left_partition.clone(), right_partition.clone()),
);
Ok((left_partition, right_partition))
}
/// Runs the same join with both the symmetric hash join and the regular
/// partitioned hash join, and asserts that their outputs are identical.
pub async fn experiment(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    filter: Option<JoinFilter>,
    join_type: JoinType,
    on: JoinOn,
    task_ctx: Arc<TaskContext>,
) -> Result<()> {
    // Symmetric hash join result; clone the shared inputs so the reference
    // run below can consume the originals.
    let symmetric_batches = partitioned_sym_join_with_filter(
        Arc::clone(&left),
        Arc::clone(&right),
        on.clone(),
        filter.clone(),
        &join_type,
        NullEquality::NullEqualsNothing,
        Arc::clone(&task_ctx),
    )
    .await?;
    // Reference result from the regular partitioned hash join.
    let reference_batches = partitioned_hash_join_with_filter(
        left,
        right,
        on,
        filter,
        &join_type,
        NullEquality::NullEqualsNothing,
        task_ctx,
    )
    .await?;
    compare_batches(&symmetric_batches, &reference_batches);
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_numeric(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(
        (4, 5),
        (12, 17),
    )]
    cardinality: (i32, i32),
) -> Result<()> {
    // Compound sort expression (`la1 + la2`) on the left, plain `ra1` on the
    // right, with a computed join key (`lc1 + 1 = rc1`).
    let ctx = Arc::new(TaskContext::default());
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let sort_l = [PhysicalSortExpr {
        expr: binary(
            col("la1", schema_l)?,
            Operator::Plus,
            col("la2", schema_l)?,
            schema_l,
        )?,
        options: SortOptions::default(),
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("ra1", schema_r)?,
        options: SortOptions::default(),
    }]
    .into();
    let join_on = vec![(
        binary(
            col("lc1", schema_l)?,
            Operator::Plus,
            lit(ScalarValue::Int32(Some(1))),
            schema_l,
        )?,
        Arc::new(Column::new_with_schema("rc1", schema_r)?) as _,
    )];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("0", DataType::Int32, true),
        Field::new("1", DataType::Int32, true),
        Field::new("2", DataType::Int32, true),
    ]);
    let filter_expr = complicated_filter(&filter_schema)?;
    let indices = vec![
        ColumnIndex { index: schema_l.index_of("la1")?, side: JoinSide::Left },
        ColumnIndex { index: schema_l.index_of("la2")?, side: JoinSide::Left },
        ColumnIndex { index: schema_r.index_of("ra1")?, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_all_one_ascending_numeric(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
    // One ascending numeric sort key on each side, across all i32 filter
    // fixtures.
    let ctx = Arc::new(TaskContext::default());
    let (left_parts, right_parts) = get_or_create_table((4, 5), 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let sort_l = [PhysicalSortExpr {
        expr: col("la1", schema_l)?,
        options: SortOptions::default(),
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("ra1", schema_r)?,
        options: SortOptions::default(),
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 0, side: JoinSide::Left },
        ColumnIndex { index: 0, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_without_sort_information(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
    // No declared orderings: the join must still be correct without pruning.
    let ctx = Arc::new(TaskContext::default());
    let (left_parts, right_parts) = get_or_create_table((4, 5), 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![], vec![])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 5, side: JoinSide::Left },
        ColumnIndex { index: 5, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_without_filter(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
) -> Result<()> {
    // No filter and no ordering: exercises the pure hash-equality path.
    let ctx = Arc::new(TaskContext::default());
    let (left_parts, right_parts) = get_or_create_table((11, 21), 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![], vec![])?;
    experiment(left, right, None, join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_all_one_descending_numeric_particular(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
    // Descending, nulls-first sort keys on both sides.
    let ctx = Arc::new(TaskContext::default());
    let (left_parts, right_parts) = get_or_create_table((11, 21), 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let descending_nulls_first = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let sort_l = [PhysicalSortExpr {
        expr: col("la1_des", schema_l)?,
        options: descending_nulls_first,
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("ra1_des", schema_r)?,
        options: descending_nulls_first,
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 5, side: JoinSide::Left },
        ColumnIndex { index: 5, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_first() -> Result<()> {
    // Ascending, nulls-first sort keys that contain nulls, full join.
    let join_type = JoinType::Full;
    let case_expr = 1;
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table((10, 11), 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let asc_nulls_first = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let sort_l = [PhysicalSortExpr {
        expr: col("l_asc_null_first", schema_l)?,
        options: asc_nulls_first,
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("r_asc_null_first", schema_r)?,
        options: asc_nulls_first,
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 6, side: JoinSide::Left },
        ColumnIndex { index: 6, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_last() -> Result<()> {
    // Ascending, nulls-last sort keys that contain nulls, full join.
    let join_type = JoinType::Full;
    let case_expr = 1;
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table((10, 11), 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let asc_nulls_last = SortOptions {
        descending: false,
        nulls_first: false,
    };
    let sort_l = [PhysicalSortExpr {
        expr: col("l_asc_null_last", schema_l)?,
        options: asc_nulls_last,
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("r_asc_null_last", schema_r)?,
        options: asc_nulls_last,
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 7, side: JoinSide::Left },
        ColumnIndex { index: 7, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_first_descending() -> Result<()> {
    // Descending, nulls-first sort keys that contain nulls, full join.
    let join_type = JoinType::Full;
    let cardinality = (10, 11);
    let case_expr = 1;
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let desc_nulls_first = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let sort_l = [PhysicalSortExpr {
        expr: col("l_desc_null_first", schema_l)?,
        options: desc_nulls_first,
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("r_desc_null_first", schema_r)?,
        options: desc_nulls_first,
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 8, side: JoinSide::Left },
        ColumnIndex { index: 8, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
    // Filter references a column (`index: 4`) with no declared ordering.
    let cardinality = (3, 4);
    let join_type = JoinType::Full;
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let sort_l = [PhysicalSortExpr {
        expr: col("la1", schema_l)?,
        options: SortOptions::default(),
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("ra1", schema_r)?,
        options: SortOptions::default(),
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("0", DataType::Int32, true),
        Field::new("1", DataType::Int32, true),
        Field::new("2", DataType::Int32, true),
    ]);
    let filter_expr = complicated_filter(&filter_schema)?;
    let indices = vec![
        ColumnIndex { index: 0, side: JoinSide::Left },
        ColumnIndex { index: 4, side: JoinSide::Left },
        ColumnIndex { index: 0, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
    // The left side advertises two orderings (`la1` and `la2`), exercising
    // equivalence-based selection of the pruning expression.
    let cardinality = (3, 4);
    let join_type = JoinType::Full;
    let config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(config));
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let sorts_l = vec![
        [PhysicalSortExpr {
            expr: col("la1", schema_l)?,
            options: SortOptions::default(),
        }]
        .into(),
        [PhysicalSortExpr {
            expr: col("la2", schema_l)?,
            options: SortOptions::default(),
        }]
        .into(),
    ];
    let sort_r = [PhysicalSortExpr {
        expr: col("ra1", schema_r)?,
        options: SortOptions::default(),
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, sorts_l, vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("0", DataType::Int32, true),
        Field::new("1", DataType::Int32, true),
        Field::new("2", DataType::Int32, true),
    ]);
    let filter_expr = complicated_filter(&filter_schema)?;
    let indices = vec![
        ColumnIndex { index: 0, side: JoinSide::Left },
        ColumnIndex { index: 4, side: JoinSide::Left },
        ColumnIndex { index: 0, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn testing_with_temporal_columns(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(
        (4, 5),
        (12, 17),
    )]
    cardinality: (i32, i32),
    #[values(0, 1, 2)] case_expr: usize,
) -> Result<()> {
    // Millisecond-timestamp sort keys on both sides.
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let asc_nulls_first = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let sort_l = [PhysicalSortExpr {
        expr: col("lt1", schema_l)?,
        options: asc_nulls_first,
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("rt1", schema_r)?,
        options: asc_nulls_first,
    }]
    .into();
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new(
            "left",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        ),
        Field::new(
            "right",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        ),
    ]);
    let filter_expr = join_expr_tests_fixture_temporal(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
        &filter_schema,
    )?;
    let indices = vec![
        ColumnIndex { index: 3, side: JoinSide::Left },
        ColumnIndex { index: 3, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn test_with_interval_columns(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(
        (4, 5),
        (12, 17),
    )]
    cardinality: (i32, i32),
) -> Result<()> {
    // Day-time interval sort keys on both sides; fixture case 0 only.
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let asc_nulls_first = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let sort_l = [PhysicalSortExpr {
        expr: col("li1", schema_l)?,
        options: asc_nulls_first,
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("ri1", schema_r)?,
        options: asc_nulls_first,
    }]
    .into();
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
        Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
    ]);
    let filter_expr = join_expr_tests_fixture_temporal(
        0,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
        &filter_schema,
    )?;
    let indices = vec![
        ColumnIndex { index: 9, side: JoinSide::Left },
        ColumnIndex { index: 9, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn testing_ascending_float_pruning(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightAnti,
        JoinType::RightMark,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(
        (4, 5),
        (12, 17),
    )]
    cardinality: (i32, i32),
    #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
    // Ascending float sort keys, across all f64 filter fixtures.
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let ctx = Arc::new(TaskContext::default().with_session_config(session_config));
    let (left_parts, right_parts) = get_or_create_table(cardinality, 8)?;
    let schema_l = &left_parts[0].schema();
    let schema_r = &right_parts[0].schema();
    let sort_l = [PhysicalSortExpr {
        expr: col("l_float", schema_l)?,
        options: SortOptions::default(),
    }]
    .into();
    let sort_r = [PhysicalSortExpr {
        expr: col("r_float", schema_r)?,
        options: SortOptions::default(),
    }]
    .into();
    let join_on = vec![(col("lc1", schema_l)?, col("rc1", schema_r)?)];
    let (left, right) =
        create_memory_table(left_parts, right_parts, vec![sort_l], vec![sort_r])?;
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Float64, true),
        Field::new("right", DataType::Float64, true),
    ]);
    let filter_expr = join_expr_tests_fixture_f64(
        case_expr,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let indices = vec![
        ColumnIndex { index: 10, side: JoinSide::Left },
        ColumnIndex { index: 10, side: JoinSide::Right },
    ];
    let filter = JoinFilter::new(filter_expr, indices, Arc::new(filter_schema));
    experiment(left, right, Some(filter), join_type, join_on, ctx).await?;
    Ok(())
}
}