use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::task::Poll;
use std::any::Any;
use ahash::RandomState;
use arrow::array::{
ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray,
PrimitiveBuilder,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use futures::{Stream, StreamExt};
use hashbrown::{raw::RawTable, HashSet};
use datafusion_common::{utils::bisect, ScalarValue};
use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval};
use crate::error::{DataFusionError, Result};
use crate::execution::context::TaskContext;
use crate::logical_expr::JoinType;
use crate::physical_plan::{
expressions::Column,
expressions::PhysicalSortExpr,
joins::{
hash_join::{build_join_indices, update_hash, JoinHashMap},
hash_join_utils::{build_filter_input_order, SortedFilterExpr},
utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
combine_join_equivalence_properties, partitioned_join_output_partitioning,
ColumnIndex, JoinFilter, JoinOn, JoinSide,
},
},
metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};
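/// A symmetric hash join keeps a hash table for *each* input. Every incoming
/// batch is appended to its own side's buffer and hash table, and is also used
/// to probe the other side's hash table, so output is produced incrementally
/// without waiting for either input to finish. This makes the operator usable
/// for unbounded (streaming) inputs, provided both children have an output
/// ordering and the join carries a range-like filter.
///
/// The join filter, together with each child's sort order, is turned into
/// [`SortedFilterExpr`]s whose intervals are propagated through an
/// [`ExprIntervalGraph`]. Whenever the intervals shrink, rows at the front of a
/// buffer that can no longer match any future probe row are emitted (for
/// outer, semi and anti joins) and then pruned from the buffer and hash table,
/// which bounds the memory used per side.
///
/// A purely illustrative query shape this operator targets (table and column
/// names here are hypothetical) is a sliding-window join such as:
///
/// ```text
/// SELECT * FROM left_stream l, right_stream r
/// WHERE l.key = r.key
///   AND l.ts > r.ts - INTERVAL '5' MINUTE
///   AND l.ts < r.ts + INTERVAL '5' MINUTE
/// ```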
pub struct SymmetricHashJoinExec {
pub(crate) left: Arc<dyn ExecutionPlan>,
pub(crate) right: Arc<dyn ExecutionPlan>,
pub(crate) on: Vec<(Column, Column)>,
pub(crate) filter: JoinFilter,
pub(crate) join_type: JoinType,
sorted_filter_exprs: Vec<SortedFilterExpr>,
left_required_sort_exprs: Vec<PhysicalSortExpr>,
right_required_sort_exprs: Vec<PhysicalSortExpr>,
physical_expr_graph: ExprIntervalGraph,
schema: SchemaRef,
random_state: RandomState,
metrics: ExecutionPlanMetricsSet,
column_indices: Vec<ColumnIndex>,
pub(crate) null_equals_null: bool,
}
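/// Per-side input metrics (batch and row counts) for the symmetric hash join.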
#[derive(Debug)]
struct SymmetricHashJoinSideMetrics {
input_batches: metrics::Count,
input_rows: metrics::Count,
}
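/// All metrics reported by [`SymmetricHashJoinExec`]: input counters for each
/// side plus output batch and row counters.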
#[derive(Debug)]
struct SymmetricHashJoinMetrics {
left: SymmetricHashJoinSideMetrics,
right: SymmetricHashJoinSideMetrics,
output_batches: metrics::Count,
output_rows: metrics::Count,
}
impl SymmetricHashJoinMetrics {
pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
let input_batches =
MetricBuilder::new(metrics).counter("input_batches", partition);
let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
let left = SymmetricHashJoinSideMetrics {
input_batches,
input_rows,
};
let input_batches =
MetricBuilder::new(metrics).counter("input_batches", partition);
let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
let right = SymmetricHashJoinSideMetrics {
input_batches,
input_rows,
};
let output_batches =
MetricBuilder::new(metrics).counter("output_batches", partition);
let output_rows = MetricBuilder::new(metrics).output_rows(partition);
Self {
left,
right,
output_batches,
output_rows,
}
}
}
impl SymmetricHashJoinExec {
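/// Tries to create a new [`SymmetricHashJoinExec`].
///
/// Returns a plan error if the `on` constraints are empty, if the join
/// columns do not match the child schemas, or if either child has no output
/// ordering (both orderings are needed to sort the filter expressions and
/// compute prune intervals).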
pub fn try_new(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: JoinFilter,
join_type: &JoinType,
null_equals_null: bool,
) -> Result<Self> {
let left_schema = left.schema();
let right_schema = right.schema();
if on.is_empty() {
return Err(DataFusionError::Plan(
"On constraints in SymmetricHashJoinExec should be non-empty".to_string(),
));
}
check_join_is_valid(&left_schema, &right_schema, &on)?;
let (schema, column_indices) =
build_join_schema(&left_schema, &right_schema, join_type);
let random_state = RandomState::with_seeds(0, 0, 0, 0);
let mut physical_expr_graph =
ExprIntervalGraph::try_new(filter.expression().clone())?;
let (left_ordering, right_ordering) = match (
left.output_ordering(),
right.output_ordering(),
) {
(Some([left_ordering, ..]), Some([right_ordering, ..])) => {
(left_ordering, right_ordering)
}
_ => {
return Err(DataFusionError::Plan(
"Symmetric hash join requires its children to have an output ordering".to_string(),
));
}
};
let left_filter_expression = build_filter_input_order(
JoinSide::Left,
&filter,
&left.schema(),
left_ordering,
)?;
let right_filter_expression = build_filter_input_order(
JoinSide::Right,
&filter,
&right.schema(),
right_ordering,
)?;
let mut sorted_filter_exprs =
vec![left_filter_expression, right_filter_expression];
let child_node_indexes = physical_expr_graph.gather_node_indices(
&sorted_filter_exprs
.iter()
.map(|sorted_expr| sorted_expr.filter_expr().clone())
.collect::<Vec<_>>(),
);
for (sorted_expr, (_, index)) in sorted_filter_exprs
.iter_mut()
.zip(child_node_indexes.iter())
{
sorted_expr.set_node_index(*index);
}
let left_required_sort_exprs = vec![left_ordering.clone()];
let right_required_sort_exprs = vec![right_ordering.clone()];
Ok(SymmetricHashJoinExec {
left,
right,
on,
filter,
join_type: *join_type,
sorted_filter_exprs,
left_required_sort_exprs,
right_required_sort_exprs,
physical_expr_graph,
schema: Arc::new(schema),
random_state,
metrics: ExecutionPlanMetricsSet::new(),
column_indices,
null_equals_null,
})
}
pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
&self.left
}
pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
&self.right
}
pub fn on(&self) -> &[(Column, Column)] {
&self.on
}
pub fn filter(&self) -> &JoinFilter {
&self.filter
}
pub fn join_type(&self) -> &JoinType {
&self.join_type
}
pub fn null_equals_null(&self) -> bool {
self.null_equals_null
}
}
impl Debug for SymmetricHashJoinExec {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(
f,
"SymmetricHashJoinExec: join_type={:?}, on={:?}, filter={:?}",
self.join_type,
self.on,
self.filter.expression()
)
}
}
impl ExecutionPlan for SymmetricHashJoinExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn required_input_ordering(&self) -> Vec<Option<&[PhysicalSortExpr]>> {
vec![
Some(&self.left_required_sort_exprs),
Some(&self.right_required_sort_exprs),
]
}
fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
Ok(children.iter().any(|u| *u))
}
fn required_input_distribution(&self) -> Vec<Distribution> {
let (left_expr, right_expr) = self
.on
.iter()
.map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
.unzip();
vec![
if self.left.output_partitioning().partition_count() == 1 {
Distribution::SinglePartition
} else {
Distribution::HashPartitioned(left_expr)
},
if self.right.output_partitioning().partition_count() == 1 {
Distribution::SinglePartition
} else {
Distribution::HashPartitioned(right_expr)
},
]
}
fn output_partitioning(&self) -> Partitioning {
let left_columns_len = self.left.schema().fields.len();
partitioned_join_output_partitioning(
self.join_type,
self.left.output_partitioning(),
self.right.output_partitioning(),
left_columns_len,
)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn equivalence_properties(&self) -> EquivalenceProperties {
let left_columns_len = self.left.schema().fields.len();
combine_join_equivalence_properties(
self.join_type,
self.left.equivalence_properties(),
self.right.equivalence_properties(),
left_columns_len,
self.on(),
self.schema(),
)
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.left.clone(), self.right.clone()]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(SymmetricHashJoinExec::try_new(
children[0].clone(),
children[1].clone(),
self.on.clone(),
self.filter.clone(),
&self.join_type,
self.null_equals_null,
)?))
}
fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
match t {
DisplayFormatType::Default => {
let display_filter = format!(", filter={:?}", self.filter.expression());
write!(
f,
"SymmetricHashJoinExec: join_type={:?}, on={:?}{}",
self.join_type, self.on, display_filter
)
}
}
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Statistics {
Statistics::default()
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
let left_side_joiner = OneSideHashJoiner::new(
JoinSide::Left,
self.sorted_filter_exprs[0].clone(),
on_left,
self.left.schema(),
);
let right_side_joiner = OneSideHashJoiner::new(
JoinSide::Right,
self.sorted_filter_exprs[1].clone(),
on_right,
self.right.schema(),
);
let left_stream = self.left.execute(partition, context.clone())?;
let right_stream = self.right.execute(partition, context)?;
Ok(Box::pin(SymmetricHashJoinStream {
left_stream,
right_stream,
schema: self.schema(),
filter: self.filter.clone(),
join_type: self.join_type,
random_state: self.random_state.clone(),
left: left_side_joiner,
right: right_side_joiner,
column_indices: self.column_indices.clone(),
metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
physical_expr_graph: self.physical_expr_graph.clone(),
null_equals_null: self.null_equals_null,
final_result: false,
probe_side: JoinSide::Left,
}))
}
}
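/// A [`Stream`] that drives the symmetric hash join: it alternates between
/// polling the left and right inputs (tracked by `probe_side`) and yields
/// joined [`RecordBatch`]es as they become available.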
struct SymmetricHashJoinStream {
left_stream: SendableRecordBatchStream,
right_stream: SendableRecordBatchStream,
schema: Arc<Schema>,
filter: JoinFilter,
join_type: JoinType,
left: OneSideHashJoiner,
right: OneSideHashJoiner,
column_indices: Vec<ColumnIndex>,
physical_expr_graph: ExprIntervalGraph,
random_state: RandomState,
null_equals_null: bool,
metrics: SymmetricHashJoinMetrics,
final_result: bool,
probe_side: JoinSide,
}
impl RecordBatchStream for SymmetricHashJoinStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
impl Stream for SymmetricHashJoinStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.poll_next_impl(cx)
}
}
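/// Removes the hash table entries of the first `prune_length` buffered rows.
/// `row_hash_values` stores the hash of every buffered row in insertion order
/// and `offset` is the global index of the first still-buffered row; the
/// corresponding indices are dropped from each hash chain, and emptied chains
/// are removed from the map entirely.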
fn prune_hash_values(
prune_length: usize,
hashmap: &mut JoinHashMap,
row_hash_values: &mut VecDeque<u64>,
offset: u64,
) -> Result<()> {
let mut hash_value_map: HashMap<u64, HashSet<u64>> = HashMap::new();
for index in 0..prune_length {
let hash_value = row_hash_values.pop_front().unwrap();
if let Some(set) = hash_value_map.get_mut(&hash_value) {
set.insert(offset + index as u64);
} else {
let mut set = HashSet::new();
set.insert(offset + index as u64);
hash_value_map.insert(hash_value, set);
}
}
for (hash_value, index_set) in hash_value_map.iter() {
if let Some((_, separation_chain)) = hashmap
.0
.get_mut(*hash_value, |(hash, _)| hash_value == hash)
{
separation_chain.retain(|n| !index_set.contains(n));
if separation_chain.is_empty() {
hashmap
.0
.remove_entry(*hash_value, |(hash, _)| hash_value == hash);
}
}
}
Ok(())
}
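/// Updates the filter expression intervals for both sides. The build side is
/// bounded by its *first* (oldest) buffered row and the probe side by the
/// *last* row of the incoming batch; ascending expressions get the interval
/// `[value, +inf)` and descending ones `(-inf, value]`, with the unbounded end
/// represented by a NULL scalar of the same data type.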
fn calculate_filter_expr_intervals(
build_input_buffer: &RecordBatch,
build_sorted_filter_expr: &mut SortedFilterExpr,
probe_batch: &RecordBatch,
probe_sorted_filter_expr: &mut SortedFilterExpr,
) -> Result<()> {
if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
return Ok(());
}
let build_array = build_sorted_filter_expr
.origin_sorted_expr()
.expr
.evaluate(&build_input_buffer.slice(0, 1))?
.into_array(1);
let probe_array = probe_sorted_filter_expr
.origin_sorted_expr()
.expr
.evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))?
.into_array(1);
for (array, sorted_expr) in vec![
(build_array, build_sorted_filter_expr),
(probe_array, probe_sorted_filter_expr),
] {
let value = ScalarValue::try_from_array(&array, 0)?;
let infinite = ScalarValue::try_from(value.get_datatype())?;
sorted_expr.set_interval(
if sorted_expr.origin_sorted_expr().options.descending {
Interval {
lower: infinite,
upper: value,
}
} else {
Interval {
lower: value,
upper: infinite,
}
},
);
}
Ok(())
}
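/// Determines how many leading rows of the build side buffer can be pruned:
/// the build side sort expression is evaluated over the whole buffer and
/// bisected against the interval bound (the upper bound for descending
/// expressions, the lower bound otherwise).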
fn determine_prune_length(
buffer: &RecordBatch,
build_side_filter_expr: &SortedFilterExpr,
) -> Result<usize> {
let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
let interval = build_side_filter_expr.interval();
let batch_arr = origin_sorted_expr
.expr
.evaluate(buffer)?
.into_array(buffer.num_rows());
let target = if origin_sorted_expr.options.descending {
interval.upper.clone()
} else {
interval.lower.clone()
};
bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
}
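/// Returns true if the given build side must emit results of its own
/// (unmatched or matched-only rows) during pruning and at the end of the
/// join, which is the case for outer, semi and anti joins on that side.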
fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
if build_side == JoinSide::Left {
matches!(
join_type,
JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi
)
} else {
matches!(
join_type,
JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi
)
}
}
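/// Returns the indices in `0..prune_length` whose rows were never matched
/// (after shifting by `deleted_offset`), used for outer and anti join output.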
fn get_anti_indices<T: ArrowPrimitiveType>(
prune_length: usize,
deleted_offset: usize,
visited_rows: &HashSet<usize>,
) -> PrimitiveArray<T>
where
NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
let mut bitmap = BooleanBufferBuilder::new(prune_length);
bitmap.append_n(prune_length, false);
for v in 0..prune_length {
let row = v + deleted_offset;
bitmap.set_bit(v, visited_rows.contains(&row));
}
(0..prune_length)
.filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
.collect()
}
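/// Returns the indices in `0..prune_length` whose rows were matched at least
/// once (after shifting by `deleted_offset`), used for semi join output.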
fn get_semi_indices<T: ArrowPrimitiveType>(
prune_length: usize,
deleted_offset: usize,
visited_rows: &HashSet<usize>,
) -> PrimitiveArray<T>
where
NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
let mut bitmap = BooleanBufferBuilder::new(prune_length);
bitmap.append_n(prune_length, false);
(0..prune_length).for_each(|v| {
let row = &(v + deleted_offset);
bitmap.set_bit(v, visited_rows.contains(row));
});
(0..prune_length)
.filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
.collect::<PrimitiveArray<T>>()
}
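/// Marks the rows referenced by `indices` (shifted by `offset`) as visited.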
fn record_visited_indices<T: ArrowPrimitiveType>(
visited: &mut HashSet<usize>,
offset: usize,
indices: &PrimitiveArray<T>,
) {
for i in indices.values() {
visited.insert(i.as_usize() + offset);
}
}
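/// Produces the `(build, probe)` index arrays for output that is determined
/// by the build side alone: unmatched build rows paired with NULLs for outer
/// and anti joins, matched build rows paired with NULLs for semi joins.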
fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
build_side: JoinSide,
prune_length: usize,
visited_rows: &HashSet<usize>,
deleted_offset: usize,
join_type: JoinType,
) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
where
NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
{
let result = match (build_side, join_type) {
(JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
| (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
| (_, JoinType::Full) => {
let build_unmatched_indices =
get_anti_indices(prune_length, deleted_offset, visited_rows);
let mut builder =
PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
builder.append_nulls(build_unmatched_indices.len());
let probe_indices = builder.finish();
(build_unmatched_indices, probe_indices)
}
(JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
let build_unmatched_indices =
get_semi_indices(prune_length, deleted_offset, visited_rows);
let mut builder =
PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
builder.append_nulls(build_unmatched_indices.len());
let probe_indices = builder.finish();
(build_unmatched_indices, probe_indices)
}
_ => unreachable!(),
};
Ok(result)
}
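/// Bookkeeping for one side of the symmetric join: the buffered input, the
/// hash table over the `on` columns, the per-row hash values (needed to prune
/// the hash table), the set of visited (matched) rows, and offsets that
/// account for rows already deleted by pruning.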
struct OneSideHashJoiner {
build_side: JoinSide,
sorted_filter_expr: SortedFilterExpr,
input_buffer: RecordBatch,
on: Vec<Column>,
hashmap: JoinHashMap,
row_hash_values: VecDeque<u64>,
hashes_buffer: Vec<u64>,
visited_rows: HashSet<usize>,
offset: usize,
deleted_offset: usize,
exhausted: bool,
}
impl OneSideHashJoiner {
pub fn new(
build_side: JoinSide,
sorted_filter_expr: SortedFilterExpr,
on: Vec<Column>,
schema: SchemaRef,
) -> Self {
Self {
build_side,
input_buffer: RecordBatch::new_empty(schema),
on,
hashmap: JoinHashMap(RawTable::with_capacity(10_000)),
row_hash_values: VecDeque::new(),
hashes_buffer: vec![],
sorted_filter_expr,
visited_rows: HashSet::new(),
offset: 0,
deleted_offset: 0,
exhausted: false,
}
}
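/// Appends `batch` to the input buffer, hashes its `on` columns into the hash
/// table and records the row hashes for later pruning.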
fn update_internal_state(
&mut self,
batch: &RecordBatch,
random_state: &RandomState,
) -> Result<()> {
self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
self.hashes_buffer.resize(batch.num_rows(), 0);
update_hash(
&self.on,
batch,
&mut self.hashmap,
self.offset,
random_state,
&mut self.hashes_buffer,
)?;
self.row_hash_values.extend(self.hashes_buffer.iter());
Ok(())
}
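/// Probes this (build) side with `probe_batch`: computes matching index pairs
/// under the join filter, records visited rows on whichever sides need them
/// for deferred output, and returns the joined batch. Returns `None` for semi
/// and anti joins, whose output is produced later from the build side only.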
#[allow(clippy::too_many_arguments)]
fn join_with_probe_batch(
&mut self,
schema: &SchemaRef,
join_type: JoinType,
on_probe: &[Column],
filter: &JoinFilter,
probe_batch: &RecordBatch,
probe_visited: &mut HashSet<usize>,
probe_offset: usize,
column_indices: &[ColumnIndex],
random_state: &RandomState,
null_equals_null: bool,
) -> Result<Option<RecordBatch>> {
if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
return Ok(Some(RecordBatch::new_empty(schema.clone())));
}
let (build_indices, probe_indices) = build_join_indices(
probe_batch,
&self.hashmap,
&self.input_buffer,
&self.on,
on_probe,
Some(filter),
random_state,
null_equals_null,
&mut self.hashes_buffer,
Some(self.deleted_offset),
self.build_side,
)?;
if need_to_produce_result_in_final(self.build_side, join_type) {
record_visited_indices(
&mut self.visited_rows,
self.deleted_offset,
&build_indices,
);
}
if need_to_produce_result_in_final(self.build_side.negate(), join_type) {
record_visited_indices(probe_visited, probe_offset, &probe_indices);
}
if matches!(
join_type,
JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftSemi
| JoinType::RightSemi
) {
Ok(None)
} else {
build_batch_from_indices(
schema,
&self.input_buffer,
probe_batch,
build_indices,
probe_indices,
column_indices,
self.build_side,
)
.map(Some)
}
}
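/// Produces the output rows that depend only on this build side over its
/// first `prune_length` buffered rows: unmatched rows (padded with NULLs on
/// the probe side) for outer and anti joins, matched rows for semi joins.
/// Returns `None` for join types that emit nothing from the build side.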
fn build_side_determined_results(
&self,
output_schema: &SchemaRef,
prune_length: usize,
probe_schema: SchemaRef,
join_type: JoinType,
column_indices: &[ColumnIndex],
) -> Result<Option<RecordBatch>> {
if need_to_produce_result_in_final(self.build_side, join_type) {
let (build_indices, probe_indices) = calculate_indices_by_join_type(
self.build_side,
prune_length,
&self.visited_rows,
self.deleted_offset,
join_type,
)?;
let empty_probe_batch = RecordBatch::new_empty(probe_schema);
build_batch_from_indices(
output_schema.as_ref(),
&self.input_buffer,
&empty_probe_batch,
build_indices,
probe_indices,
column_indices,
self.build_side,
)
.map(Some)
} else {
Ok(None)
}
}
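/// Prunes this build side after a probe batch has been processed: the filter
/// interval graph is updated with both sides' intervals, the number of
/// prunable leading rows is determined, any build-side-only results for those
/// rows are produced, and the rows are then removed from the hash table, the
/// visited set and the input buffer.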
fn prune_with_probe_batch(
&mut self,
schema: &SchemaRef,
probe_batch: &RecordBatch,
probe_side_sorted_filter_expr: &mut SortedFilterExpr,
join_type: JoinType,
column_indices: &[ColumnIndex],
physical_expr_graph: &mut ExprIntervalGraph,
) -> Result<Option<RecordBatch>> {
if self.input_buffer.num_rows() == 0 {
return Ok(None);
}
let mut filter_intervals = vec![
(
self.sorted_filter_expr.node_index(),
self.sorted_filter_expr.interval().clone(),
),
(
probe_side_sorted_filter_expr.node_index(),
probe_side_sorted_filter_expr.interval().clone(),
),
];
physical_expr_graph.update_ranges(&mut filter_intervals)?;
let calculated_build_side_interval = filter_intervals.remove(0).1;
if calculated_build_side_interval.eq(self.sorted_filter_expr.interval()) {
return Ok(None);
}
self.sorted_filter_expr
.set_interval(calculated_build_side_interval);
let prune_length =
determine_prune_length(&self.input_buffer, &self.sorted_filter_expr)?;
if prune_length == 0 {
return Ok(None);
}
let result = self.build_side_determined_results(
schema,
prune_length,
probe_batch.schema(),
join_type,
column_indices,
);
prune_hash_values(
prune_length,
&mut self.hashmap,
&mut self.row_hash_values,
self.deleted_offset as u64,
)?;
for row in self.deleted_offset..(self.deleted_offset + prune_length) {
self.visited_rows.remove(&row);
}
self.input_buffer = self
.input_buffer
.slice(prune_length, self.input_buffer.num_rows() - prune_length);
self.deleted_offset += prune_length;
result
}
}
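/// Concatenates two optional batches into one, returning `None` only when
/// both inputs are `None`.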
fn combine_two_batches(
output_schema: &SchemaRef,
left_batch: Option<RecordBatch>,
right_batch: Option<RecordBatch>,
) -> Result<Option<RecordBatch>> {
match (left_batch, right_batch) {
(Some(batch), None) | (None, Some(batch)) => {
Ok(Some(batch))
}
(Some(left_batch), Some(right_batch)) => {
concat_batches(output_schema, &[left_batch, right_batch])
.map_err(DataFusionError::ArrowError)
.map(Some)
}
(None, None) => {
Ok(None)
}
}
}
impl SymmetricHashJoinStream {
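/// The inner poll loop: while neither input is exhausted, the current probe
/// side's stream is polled; each ready batch is added to its own side's hash
/// table, joined against the opposite (build) side, and used to prune that
/// side, and the combined result is returned. Once both inputs are exhausted,
/// the remaining build-side-only results from both sides are emitted exactly
/// once.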
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
loop {
if self.final_result {
return Poll::Ready(None);
}
if self.right.exhausted && self.left.exhausted {
let left_result = self.left.build_side_determined_results(
&self.schema,
self.left.input_buffer.num_rows(),
self.right.input_buffer.schema(),
self.join_type,
&self.column_indices,
)?;
let right_result = self.right.build_side_determined_results(
&self.schema,
self.right.input_buffer.num_rows(),
self.left.input_buffer.schema(),
self.join_type,
&self.column_indices,
)?;
self.final_result = true;
let result =
combine_two_batches(&self.schema, left_result, right_result)?;
if let Some(batch) = &result {
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
return Poll::Ready(Ok(result).transpose());
} else {
continue;
}
}
let (
input_stream,
probe_hash_joiner,
build_hash_joiner,
build_join_side,
probe_side_metrics,
) = if self.probe_side.eq(&JoinSide::Left) {
(
&mut self.left_stream,
&mut self.left,
&mut self.right,
JoinSide::Right,
&mut self.metrics.left,
)
} else {
(
&mut self.right_stream,
&mut self.right,
&mut self.left,
JoinSide::Left,
&mut self.metrics.right,
)
};
match input_stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(probe_batch))) => {
probe_side_metrics.input_batches.add(1);
probe_side_metrics.input_rows.add(probe_batch.num_rows());
probe_hash_joiner
.update_internal_state(&probe_batch, &self.random_state)?;
calculate_filter_expr_intervals(
&build_hash_joiner.input_buffer,
&mut build_hash_joiner.sorted_filter_expr,
&probe_batch,
&mut probe_hash_joiner.sorted_filter_expr,
)?;
let equal_result = build_hash_joiner.join_with_probe_batch(
&self.schema,
self.join_type,
&probe_hash_joiner.on,
&self.filter,
&probe_batch,
&mut probe_hash_joiner.visited_rows,
probe_hash_joiner.offset,
&self.column_indices,
&self.random_state,
self.null_equals_null,
)?;
probe_hash_joiner.offset += probe_batch.num_rows();
let anti_result = build_hash_joiner.prune_with_probe_batch(
&self.schema,
&probe_batch,
&mut probe_hash_joiner.sorted_filter_expr,
self.join_type,
&self.column_indices,
&mut self.physical_expr_graph,
)?;
let result =
combine_two_batches(&self.schema, equal_result, anti_result)?;
if !build_hash_joiner.exhausted {
self.probe_side = build_join_side;
}
if let Some(batch) = &result {
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
return Poll::Ready(Ok(result).transpose());
}
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => {
probe_hash_joiner.exhausted = true;
self.probe_side = build_join_side;
}
Poll::Pending => {
if !build_hash_joiner.exhausted {
self.probe_side = build_join_side;
} else {
return Poll::Pending;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::fs::File;
use arrow::array::ArrayRef;
use arrow::array::{Int32Array, TimestampNanosecondArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::util::pretty::pretty_format_batches;
use rstest::*;
use tempfile::TempDir;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, col, Column};
use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr;
use datafusion_physical_expr::PhysicalExpr;
use crate::physical_plan::joins::{
hash_join_utils::tests::complicated_filter, HashJoinExec, PartitionMode,
};
use crate::physical_plan::{
collect, common, memory::MemoryExec, repartition::RepartitionExec,
};
use crate::prelude::{SessionConfig, SessionContext};
use crate::test_util;
use super::*;
const TABLE_SIZE: i32 = 1_000;
fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) {
let first_formatted = pretty_format_batches(collected_1).unwrap().to_string();
let second_formatted = pretty_format_batches(collected_2).unwrap().to_string();
let mut first_formatted_sorted: Vec<&str> =
first_formatted.trim().lines().collect();
first_formatted_sorted.sort_unstable();
let mut second_formatted_sorted: Vec<&str> =
second_formatted.trim().lines().collect();
second_formatted_sorted.sort_unstable();
for (i, (first_line, second_line)) in first_formatted_sorted
.iter()
.zip(&second_formatted_sorted)
.enumerate()
{
assert_eq!((i, first_line), (i, second_line));
}
}
#[allow(clippy::too_many_arguments)]
async fn partitioned_sym_join_with_filter(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: JoinFilter,
join_type: &JoinType,
null_equals_null: bool,
context: Arc<TaskContext>,
) -> Result<Vec<RecordBatch>> {
let partition_count = 4;
let left_expr = on
.iter()
.map(|(l, _)| Arc::new(l.clone()) as _)
.collect::<Vec<_>>();
let right_expr = on
.iter()
.map(|(_, r)| Arc::new(r.clone()) as _)
.collect::<Vec<_>>();
let join = SymmetricHashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
left,
Partitioning::Hash(left_expr, partition_count),
)?),
Arc::new(RepartitionExec::try_new(
right,
Partitioning::Hash(right_expr, partition_count),
)?),
on,
filter,
join_type,
null_equals_null,
)?;
let mut batches = vec![];
for i in 0..partition_count {
let stream = join.execute(i, context.clone())?;
let more_batches = common::collect(stream).await?;
batches.extend(
more_batches
.into_iter()
.filter(|b| b.num_rows() > 0)
.collect::<Vec<_>>(),
);
}
Ok(batches)
}
#[allow(clippy::too_many_arguments)]
async fn partitioned_hash_join_with_filter(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: JoinFilter,
join_type: &JoinType,
null_equals_null: bool,
context: Arc<TaskContext>,
) -> Result<Vec<RecordBatch>> {
let partition_count = 4;
let (left_expr, right_expr) = on
.iter()
.map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
.unzip();
let join = HashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
left,
Partitioning::Hash(left_expr, partition_count),
)?),
Arc::new(RepartitionExec::try_new(
right,
Partitioning::Hash(right_expr, partition_count),
)?),
on,
Some(filter),
join_type,
PartitionMode::Partitioned,
null_equals_null,
)?;
let mut batches = vec![];
for i in 0..partition_count {
let stream = join.execute(i, context.clone())?;
let more_batches = common::collect(stream).await?;
batches.extend(
more_batches
.into_iter()
.filter(|b| b.num_rows() > 0)
.collect::<Vec<_>>(),
);
}
Ok(batches)
}
pub fn split_record_batches(
batch: &RecordBatch,
batch_size: usize,
) -> Result<Vec<RecordBatch>> {
let row_num = batch.num_rows();
let number_of_batch = row_num / batch_size;
let mut sizes = vec![batch_size; number_of_batch];
sizes.push(row_num - (batch_size * number_of_batch));
let mut result = vec![];
for (i, size) in sizes.iter().enumerate() {
result.push(batch.slice(i * batch_size, *size));
}
Ok(result)
}
fn join_expr_tests_fixture(
expr_id: usize,
left_col: Arc<dyn PhysicalExpr>,
right_col: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
match expr_id {
0 => gen_conjunctive_numeric_expr(
left_col,
right_col,
Operator::Plus,
Operator::Plus,
Operator::Plus,
Operator::Plus,
1,
5,
3,
10,
),
1 => gen_conjunctive_numeric_expr(
left_col,
right_col,
Operator::Minus,
Operator::Plus,
Operator::Plus,
Operator::Plus,
1,
5,
3,
10,
),
2 => gen_conjunctive_numeric_expr(
left_col,
right_col,
Operator::Minus,
Operator::Plus,
Operator::Minus,
Operator::Plus,
1,
5,
3,
10,
),
3 => gen_conjunctive_numeric_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Plus,
10,
5,
3,
10,
),
4 => gen_conjunctive_numeric_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
10,
5,
30,
3,
),
_ => unreachable!(),
}
}
fn build_sides_record_batches(
table_size: i32,
key_cardinality: (i32, i32),
) -> Result<(RecordBatch, RecordBatch)> {
let null_ratio: f64 = 0.4;
let initial_range = 0..table_size;
let index = (table_size as f64 * null_ratio).round() as i32;
let rest_of = index..table_size;
let ordered: ArrayRef = Arc::new(Int32Array::from_iter(
initial_range.clone().collect::<Vec<i32>>(),
));
let ordered_des = Arc::new(Int32Array::from_iter(
initial_range.clone().rev().collect::<Vec<i32>>(),
));
let cardinality = Arc::new(Int32Array::from_iter(
initial_range.clone().map(|x| x % 4).collect::<Vec<i32>>(),
));
let cardinality_key = Arc::new(Int32Array::from_iter(
initial_range
.clone()
.map(|x| x % key_cardinality.0)
.collect::<Vec<i32>>(),
));
let ordered_asc_null_first = Arc::new(Int32Array::from_iter({
std::iter::repeat(None)
.take(index as usize)
.chain(rest_of.clone().map(Some))
.collect::<Vec<Option<i32>>>()
}));
let ordered_asc_null_last = Arc::new(Int32Array::from_iter({
rest_of
.clone()
.map(Some)
.chain(std::iter::repeat(None).take(index as usize))
.collect::<Vec<Option<i32>>>()
}));
let ordered_desc_null_first = Arc::new(Int32Array::from_iter({
std::iter::repeat(None)
.take(index as usize)
.chain(rest_of.rev().map(Some))
.collect::<Vec<Option<i32>>>()
}));
let time = Arc::new(TimestampNanosecondArray::from(
initial_range
.map(|x| 1664264591000000000 + (5000000000 * (x as i64)))
.collect::<Vec<i64>>(),
));
let left = RecordBatch::try_from_iter(vec![
("la1", ordered.clone()),
("lb1", cardinality.clone()),
("lc1", cardinality_key.clone()),
("lt1", time.clone()),
("la2", ordered.clone()),
("la1_des", ordered_des.clone()),
("l_asc_null_first", ordered_asc_null_first.clone()),
("l_asc_null_last", ordered_asc_null_last.clone()),
("l_desc_null_first", ordered_desc_null_first.clone()),
])?;
let right = RecordBatch::try_from_iter(vec![
("ra1", ordered.clone()),
("rb1", cardinality),
("rc1", cardinality_key),
("rt1", time),
("ra2", ordered),
("ra1_des", ordered_des),
("r_asc_null_first", ordered_asc_null_first),
("r_asc_null_last", ordered_asc_null_last),
("r_desc_null_first", ordered_desc_null_first),
])?;
Ok((left, right))
}
fn create_memory_table(
left_batch: RecordBatch,
right_batch: RecordBatch,
left_sorted: Vec<PhysicalSortExpr>,
right_sorted: Vec<PhysicalSortExpr>,
batch_size: usize,
) -> Result<(Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>)> {
Ok((
Arc::new(
MemoryExec::try_new(
&[split_record_batches(&left_batch, batch_size).unwrap()],
left_batch.schema(),
None,
)?
.with_sort_information(left_sorted),
),
Arc::new(
MemoryExec::try_new(
&[split_record_batches(&right_batch, batch_size).unwrap()],
right_batch.schema(),
None,
)?
.with_sort_information(right_sorted),
),
))
}
async fn experiment(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
filter: JoinFilter,
join_type: JoinType,
on: JoinOn,
task_ctx: Arc<TaskContext>,
) -> Result<()> {
let first_batches = partitioned_sym_join_with_filter(
left.clone(),
right.clone(),
on.clone(),
filter.clone(),
&join_type,
false,
task_ctx.clone(),
)
.await?;
let second_batches = partitioned_hash_join_with_filter(
left, right, on, filter, &join_type, false, task_ctx,
)
.await?;
compare_batches(&first_batches, &second_batches);
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_numeric(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(11, 21),
(31, 71),
(99, 12),
)]
cardinality: (i32, i32),
) -> Result<()> {
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: binary(
col("la1", left_schema)?,
Operator::Plus,
col("la2", left_schema)?,
left_schema,
)?,
options: SortOptions::default(),
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", right_schema)?,
options: SortOptions::default(),
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
Field::new("1", DataType::Int32, true),
Field::new("2", DataType::Int32, true),
]);
let filter_expr = complicated_filter(&intermediate_schema)?;
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 4,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_all_one_ascending_numeric(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(11, 21),
(31, 71),
(99, 12),
)]
cardinality: (i32, i32),
#[values(0, 1, 2, 3, 4)] case_expr: usize,
) -> Result<()> {
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1", left_schema)?,
options: SortOptions::default(),
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", right_schema)?,
options: SortOptions::default(),
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn single_test() -> Result<()> {
let case_expr = 1;
let cardinality = (11, 21);
let join_type = JoinType::Full;
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1_des", left_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1_des", right_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 5,
side: JoinSide::Left,
},
ColumnIndex {
index: 5,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_all_one_descending_numeric_particular(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(11, 21),
(31, 71),
(99, 12),
)]
cardinality: (i32, i32),
#[values(0, 1, 2, 3, 4)] case_expr: usize,
) -> Result<()> {
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1_des", left_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1_des", right_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 5,
side: JoinSide::Left,
},
ColumnIndex {
index: 5,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn join_change_in_planner() -> Result<()> {
let config = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::with_config(config);
let tmp_dir = TempDir::new().unwrap();
let left_file_path = tmp_dir.path().join("left.csv");
File::create(left_file_path.clone()).unwrap();
test_util::test_create_unbounded_sorted_file(
&ctx,
left_file_path.clone(),
"left",
)
.await?;
let right_file_path = tmp_dir.path().join("right.csv");
File::create(right_file_path.clone()).unwrap();
test_util::test_create_unbounded_sorted_file(
&ctx,
right_file_path.clone(),
"right",
)
.await?;
let df = ctx.sql("EXPLAIN SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?;
let physical_plan = df.create_physical_plan().await?;
let task_ctx = ctx.task_ctx();
let results = collect(physical_plan.clone(), task_ctx).await.unwrap();
let formatted = pretty_format_batches(&results).unwrap().to_string();
let found = formatted
.lines()
.any(|line| line.contains("SymmetricHashJoinExec"));
assert!(found);
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_first() -> Result<()> {
let join_type = JoinType::Full;
let cardinality = (10, 11);
let case_expr = 1;
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("l_asc_null_first", left_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("r_asc_null_first", right_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 6,
side: JoinSide::Left,
},
ColumnIndex {
index: 6,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_last() -> Result<()> {
let join_type = JoinType::Full;
let cardinality = (10, 11);
let case_expr = 1;
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("l_asc_null_last", left_schema)?,
options: SortOptions {
descending: false,
nulls_first: false,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("r_asc_null_last", right_schema)?,
options: SortOptions {
descending: false,
nulls_first: false,
},
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 7,
side: JoinSide::Left,
},
ColumnIndex {
index: 7,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_first_descending() -> Result<()> {
let join_type = JoinType::Full;
let cardinality = (10, 11);
let case_expr = 1;
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("l_desc_null_first", left_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("r_desc_null_first", right_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 8,
side: JoinSide::Left,
},
ColumnIndex {
index: 8,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
let cardinality = (3, 4);
let join_type = JoinType::Full;
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1", left_schema)?,
options: SortOptions::default(),
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", right_schema)?,
options: SortOptions::default(),
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?;
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
Field::new("1", DataType::Int32, true),
Field::new("2", DataType::Int32, true),
]);
let filter_expr = complicated_filter(&intermediate_schema)?;
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 4,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, filter, join_type, on, task_ctx).await?;
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn test_one_side_hash_joiner_visited_rows(
#[values(
(JoinType::Inner, true),
(JoinType::Left,false),
(JoinType::Right, true),
(JoinType::RightSemi, true),
(JoinType::LeftSemi, false),
(JoinType::LeftAnti, false),
(JoinType::RightAnti, true),
(JoinType::Full, false),
)]
case: (JoinType, bool),
) -> Result<()> {
let join_type = case.0;
let should_be_empty = case.1;
let random_state = RandomState::with_seeds(0, 0, 0, 0);
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) = build_sides_record_batches(20, (1, 1))?;
let left_schema = left_batch.schema();
let right_schema = right_batch.schema();
let (schema, join_column_indices) =
build_join_schema(&left_schema, &right_schema, &join_type);
let join_schema = Arc::new(schema);
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1", &left_schema)?,
options: SortOptions::default(),
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", &right_schema)?,
options: SortOptions::default(),
}];
let (left, right) =
create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 10)?;
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
Field::new("1", DataType::Int32, true),
]);
let filter_expr = gen_conjunctive_numeric_expr(
col("0", &intermediate_schema)?,
col("1", &intermediate_schema)?,
Operator::Plus,
Operator::Minus,
Operator::Plus,
Operator::Plus,
0,
3,
0,
3,
);
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
let left_sorted_filter_expr = SortedFilterExpr::new(
PhysicalSortExpr {
expr: col("la1", &left_schema)?,
options: SortOptions::default(),
},
Arc::new(Column::new("0", 0)),
);
let mut left_side_joiner = OneSideHashJoiner::new(
JoinSide::Left,
left_sorted_filter_expr,
vec![Column::new_with_schema("lc1", &left_schema)?],
left_schema,
);
let right_sorted_filter_expr = SortedFilterExpr::new(
PhysicalSortExpr {
expr: col("ra1", &right_schema)?,
options: SortOptions::default(),
},
Arc::new(Column::new("1", 0)),
);
let mut right_side_joiner = OneSideHashJoiner::new(
JoinSide::Right,
right_sorted_filter_expr,
vec![Column::new_with_schema("rc1", &right_schema)?],
right_schema,
);
let mut left_stream = left.execute(0, task_ctx.clone())?;
let mut right_stream = right.execute(0, task_ctx)?;
let initial_left_batch = left_stream.next().await.unwrap()?;
left_side_joiner.update_internal_state(&initial_left_batch, &random_state)?;
assert_eq!(
left_side_joiner.input_buffer.num_rows(),
initial_left_batch.num_rows()
);
let initial_right_batch = right_stream.next().await.unwrap()?;
right_side_joiner.update_internal_state(&initial_right_batch, &random_state)?;
assert_eq!(
right_side_joiner.input_buffer.num_rows(),
initial_right_batch.num_rows()
);
left_side_joiner.join_with_probe_batch(
&join_schema,
join_type,
&right_side_joiner.on,
&filter,
&initial_right_batch,
&mut right_side_joiner.visited_rows,
right_side_joiner.offset,
&join_column_indices,
&random_state,
false,
)?;
assert_eq!(left_side_joiner.visited_rows.is_empty(), should_be_empty);
Ok(())
}
}