use std::any::Any;
use std::fmt::{self, Debug};
use std::sync::Arc;
use std::task::Poll;
use std::{usize, vec};
use crate::common::SharedMemoryReservation;
use crate::joins::hash_join::{equal_rows_arr, update_hash};
use crate::joins::stream_join_utils::{
calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr,
StreamJoinMetrics,
};
use crate::joins::utils::{
apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
check_join_is_valid, partitioned_join_output_partitioning, ColumnIndex, JoinFilter,
JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult,
};
use crate::{
execution_mode_from_children,
expressions::PhysicalSortExpr,
joins::StreamJoinPartitionMode,
metrics::{ExecutionPlanMetricsSet, MetricsSet},
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use arrow::array::{
ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
UInt64Array,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::bisect;
use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
use ahash::RandomState;
use futures::Stream;
use hashbrown::HashSet;
use parking_lot::Mutex;
/// Scale factor forwarded to `PruningJoinHashMap::prune_hash_values` by
/// `OneSideHashJoiner::prune_internal_state`. Presumably it controls when the
/// pruned hash map releases (shrinks) its storage — TODO confirm against
/// `prune_hash_values`.
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
/// Physical plan node for a symmetric hash join.
///
/// Both inputs are consumed incrementally. Each side keeps its own buffered
/// rows and hash table (`OneSideHashJoiner`), and every incoming batch from
/// one side is inserted into its own table and then probed against the other
/// side. When both inputs carry sort expressions and a join filter is present,
/// buffered rows that can no longer satisfy the filter are pruned during
/// execution (see `perform_join_for_given_side`).
#[derive(Debug)]
pub struct SymmetricHashJoinExec {
/// Left input plan.
pub(crate) left: Arc<dyn ExecutionPlan>,
/// Right input plan.
pub(crate) right: Arc<dyn ExecutionPlan>,
/// Equi-join key pairs as `(left expression, right expression)`.
pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
/// Optional non-equi join filter applied to candidate matches.
pub(crate) filter: Option<JoinFilter>,
/// Join type (inner, outer, semi, anti, ...).
pub(crate) join_type: JoinType,
/// Hashing state for join keys; created with fixed seeds in `try_new`.
random_state: RandomState,
/// Execution metrics for this operator.
metrics: ExecutionPlanMetricsSet,
/// Maps each output column to its source side and column index, as produced
/// by `build_join_schema`.
column_indices: Vec<ColumnIndex>,
/// Whether null join keys compare equal to each other.
pub(crate) null_equals_null: bool,
/// Ordering required on the left input, if any.
pub(crate) left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
/// Ordering required on the right input, if any.
pub(crate) right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
/// Whether inputs are hash-partitioned or a single partition.
mode: StreamJoinPartitionMode,
/// Plan properties computed once in `try_new`.
cache: PlanProperties,
}
impl SymmetricHashJoinExec {
#[allow(clippy::too_many_arguments)]
pub fn try_new(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
filter: Option<JoinFilter>,
join_type: &JoinType,
null_equals_null: bool,
left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
mode: StreamJoinPartitionMode,
) -> Result<Self> {
let left_schema = left.schema();
let right_schema = right.schema();
if on.is_empty() {
return plan_err!(
"On constraints in SymmetricHashJoinExec should be non-empty"
);
}
check_join_is_valid(&left_schema, &right_schema, &on)?;
let (schema, column_indices) =
build_join_schema(&left_schema, &right_schema, join_type);
let random_state = RandomState::with_seeds(0, 0, 0, 0);
let schema = Arc::new(schema);
let cache =
Self::compute_properties(&left, &right, schema.clone(), *join_type, &on);
Ok(SymmetricHashJoinExec {
left,
right,
on,
filter,
join_type: *join_type,
random_state,
metrics: ExecutionPlanMetricsSet::new(),
column_indices,
null_equals_null,
left_sort_exprs,
right_sort_exprs,
mode,
cache,
})
}
fn compute_properties(
left: &Arc<dyn ExecutionPlan>,
right: &Arc<dyn ExecutionPlan>,
schema: SchemaRef,
join_type: JoinType,
join_on: JoinOnRef,
) -> PlanProperties {
let eq_properties = join_equivalence_properties(
left.equivalence_properties().clone(),
right.equivalence_properties().clone(),
&join_type,
schema,
&[false, false],
None,
join_on,
);
let left_columns_len = left.schema().fields.len();
let output_partitioning = partitioned_join_output_partitioning(
join_type,
left.output_partitioning(),
right.output_partitioning(),
left_columns_len,
);
let mode = execution_mode_from_children([left, right]);
PlanProperties::new(eq_properties, output_partitioning, mode)
}
pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
&self.left
}
pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
&self.right
}
pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
&self.on
}
pub fn filter(&self) -> Option<&JoinFilter> {
self.filter.as_ref()
}
pub fn join_type(&self) -> &JoinType {
&self.join_type
}
pub fn null_equals_null(&self) -> bool {
self.null_equals_null
}
pub fn partition_mode(&self) -> StreamJoinPartitionMode {
self.mode
}
pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
self.left_sort_exprs.as_deref()
}
pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
self.right_sort_exprs.as_deref()
}
pub fn check_if_order_information_available(&self) -> Result<bool> {
if let Some(filter) = self.filter() {
let left = self.left();
if let Some(left_ordering) = left.output_ordering() {
let right = self.right();
if let Some(right_ordering) = right.output_ordering() {
let left_convertible = convert_sort_expr_with_filter_schema(
&JoinSide::Left,
filter,
&left.schema(),
&left_ordering[0],
)?
.is_some();
let right_convertible = convert_sort_expr_with_filter_schema(
&JoinSide::Right,
filter,
&right.schema(),
&right_ordering[0],
)?
.is_some();
return Ok(left_convertible && right_convertible);
}
}
}
Ok(false)
}
}
impl DisplayAs for SymmetricHashJoinExec {
    /// Writes a one-line summary of the join.
    ///
    /// The summary lists the partition mode, join type and equi-join keys,
    /// followed by the join filter when one is present.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let key_pairs: Vec<String> = self
                    .on
                    .iter()
                    .map(|(left_key, right_key)| format!("({}, {})", left_key, right_key))
                    .collect();
                let filter_suffix = match &self.filter {
                    Some(join_filter) => {
                        format!(", filter={}", join_filter.expression())
                    }
                    None => String::new(),
                };
                write!(
                    f,
                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
                    self.mode,
                    self.join_type,
                    key_pairs.join(", "),
                    filter_suffix
                )
            }
        }
    }
}
impl ExecutionPlan for SymmetricHashJoinExec {
    fn name(&self) -> &'static str {
        "SymmetricHashJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    /// Returns the input distribution each child must satisfy.
    ///
    /// In `Partitioned` mode, each input must be hash-partitioned on its own
    /// side of the join keys. In `SinglePartition` mode, each input must
    /// consist of a single partition.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        match self.mode {
            StreamJoinPartitionMode::Partitioned => {
                let (left_expr, right_expr) = self
                    .on
                    .iter()
                    .map(|(l, r)| (l.clone() as _, r.clone() as _))
                    .unzip();
                vec![
                    Distribution::HashPartitioned(left_expr),
                    Distribution::HashPartitioned(right_expr),
                ]
            }
            StreamJoinPartitionMode::SinglePartition => {
                vec![Distribution::SinglePartition, Distribution::SinglePartition]
            }
        }
    }

    /// Requires each input to follow its configured sort expressions, if any.
    fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
        vec![
            self.left_sort_exprs
                .as_ref()
                .map(PhysicalSortRequirement::from_sort_exprs),
            self.right_sort_exprs
                .as_ref()
                .map(PhysicalSortRequirement::from_sort_exprs),
        ]
    }

    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.left.clone(), self.right.clone()]
    }

    /// Rebuilds this join over new `[left, right]` children, keeping all
    /// other settings.
    ///
    /// # Errors
    ///
    /// Returns an internal error if `children` does not contain exactly two
    /// plans, instead of panicking on an out-of-bounds index. Also propagates
    /// errors from [`SymmetricHashJoinExec::try_new`].
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.as_slice() {
            [left, right] => Ok(Arc::new(SymmetricHashJoinExec::try_new(
                left.clone(),
                right.clone(),
                self.on.clone(),
                self.filter.clone(),
                &self.join_type,
                self.null_equals_null,
                self.left_sort_exprs.clone(),
                self.right_sort_exprs.clone(),
                self.mode,
            )?)),
            _ => internal_err!(
                "SymmetricHashJoinExec expects exactly 2 children, got {}",
                children.len()
            ),
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Statistics are not estimated for this operator.
    fn statistics(&self) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&self.schema()))
    }

    /// Starts executing `partition` of the join and returns its output stream.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the two inputs have different partition
    /// counts. Also propagates errors from preparing the sorted filter
    /// expressions, from executing either child, and from reserving memory
    /// for the interval graph.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let left_partitions = self.left.output_partitioning().partition_count();
        let right_partitions = self.right.output_partitioning().partition_count();
        if left_partitions != right_partitions {
            // The space before the `\` is needed because a string continuation
            // strips the leading whitespace of the next line.
            return internal_err!(
                "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions}, \
                 consider using RepartitionExec"
            );
        }
        // Interval-based pruning is only set up when both sides have sort
        // expressions and a filter exists; otherwise all three stay `None`.
        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
            match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) {
                (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
                    let (left, right, graph) = prepare_sorted_exprs(
                        filter,
                        &self.left,
                        &self.right,
                        left_sort_exprs,
                        right_sort_exprs,
                    )?;
                    (Some(left), Some(right), Some(graph))
                }
                _ => (None, None, None),
            };
        let (on_left, on_right) = self.on.iter().cloned().unzip();
        let left_side_joiner =
            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
        let right_side_joiner =
            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
        let left_stream = self.left.execute(partition, context.clone())?;
        let right_stream = self.right.execute(partition, context.clone())?;
        let reservation = Arc::new(Mutex::new(
            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
                .register(context.memory_pool()),
        ));
        // The interval graph is held for the stream's lifetime; account for it up front.
        if let Some(g) = graph.as_ref() {
            reservation.lock().try_grow(g.size())?;
        }
        Ok(Box::pin(SymmetricHashJoinStream {
            left_stream,
            right_stream,
            schema: self.schema(),
            filter: self.filter.clone(),
            join_type: self.join_type,
            random_state: self.random_state.clone(),
            left: left_side_joiner,
            right: right_side_joiner,
            column_indices: self.column_indices.clone(),
            metrics: StreamJoinMetrics::new(partition, &self.metrics),
            graph,
            left_sorted_filter_expr,
            right_sorted_filter_expr,
            null_equals_null: self.null_equals_null,
            state: EagerJoinStreamState::PullRight,
            reservation,
        }))
    }
}
/// Output stream of one partition of [`SymmetricHashJoinExec`].
///
/// Alternately pulls batches from both inputs, keeps per-side buffered state
/// and emits joined rows as soon as they can be determined.
struct SymmetricHashJoinStream {
/// Input stream of the left side.
left_stream: SendableRecordBatchStream,
/// Input stream of the right side.
right_stream: SendableRecordBatchStream,
/// Output schema of the join.
schema: Arc<Schema>,
/// Optional non-equi join filter.
filter: Option<JoinFilter>,
/// Join type.
join_type: JoinType,
/// Buffered state and hash table of the left side.
left: OneSideHashJoiner,
/// Buffered state and hash table of the right side.
right: OneSideHashJoiner,
/// Maps output columns to their source side and column index.
column_indices: Vec<ColumnIndex>,
/// Interval graph of the join filter; `Some` only when both sides have sort
/// expressions and a filter exists (see `execute`).
graph: Option<ExprIntervalGraph>,
/// Left side's sorted filter expression used for pruning, if any.
left_sorted_filter_expr: Option<SortedFilterExpr>,
/// Right side's sorted filter expression used for pruning, if any.
right_sorted_filter_expr: Option<SortedFilterExpr>,
/// Hashing state shared by both sides' hash tables.
random_state: RandomState,
/// Whether null join keys compare equal.
null_equals_null: bool,
/// Input/output and memory metrics.
metrics: StreamJoinMetrics,
/// Memory reservation, resized to `size()` after every processed batch.
reservation: SharedMemoryReservation,
/// Current state of the `EagerJoinStream` state machine.
state: EagerJoinStreamState,
}
impl RecordBatchStream for SymmetricHashJoinStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
/// Async stream of joined `RecordBatch`es.
impl Stream for SymmetricHashJoinStream {
type Item = Result<RecordBatch>;
/// Delegates to `poll_next_impl`. That method is not defined in this impl;
/// presumably it is provided by the `EagerJoinStream` trait implemented for
/// this type — TODO confirm.
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.poll_next_impl(cx)
}
}
/// Computes how many leading rows of the build-side `buffer` can be pruned.
///
/// The build side's original sort expression is evaluated over `buffer`.
/// The resulting (sorted) column is then bisected (`bisect::<true>`) against
/// one bound of the build side's current filter interval: the upper bound
/// for a descending sort, the lower bound otherwise. The returned index is
/// the prune length.
fn determine_prune_length(
    buffer: &RecordBatch,
    build_side_filter_expr: &SortedFilterExpr,
) -> Result<usize> {
    let sort_expr = build_side_filter_expr.origin_sorted_expr();
    let sort_options = sort_expr.options;
    let bound = build_side_filter_expr.interval();
    let sort_key_values = sort_expr
        .expr
        .evaluate(buffer)?
        .into_array(buffer.num_rows())?;
    let threshold = if sort_options.descending {
        bound.upper()
    } else {
        bound.lower()
    };
    bisect::<true>(&[sort_key_values], &[threshold.clone()], &[sort_options])
}
/// Reports whether rows buffered on `build_side` may need to be emitted on
/// their own later, either once they are pruned or when the inputs end.
///
/// This holds for the outer, anti and semi join types that preserve that
/// side. `Full` joins preserve both sides.
fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
    match build_side {
        JoinSide::Left => matches!(
            join_type,
            JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi
        ),
        _ => matches!(
            join_type,
            JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi
        ),
    }
}
fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
build_side: JoinSide,
prune_length: usize,
visited_rows: &HashSet<usize>,
deleted_offset: usize,
join_type: JoinType,
) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
where
NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
{
let result = match (build_side, join_type) {
(JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
| (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
| (_, JoinType::Full) => {
let build_unmatched_indices =
get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
let mut builder =
PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
builder.append_nulls(build_unmatched_indices.len());
let probe_indices = builder.finish();
(build_unmatched_indices, probe_indices)
}
(JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
let build_unmatched_indices =
get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
let mut builder =
PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
builder.append_nulls(build_unmatched_indices.len());
let probe_indices = builder.finish();
(build_unmatched_indices, probe_indices)
}
_ => unreachable!(),
};
Ok(result)
}
/// Emits output rows for the first `prune_length` buffered build-side rows
/// whose result no longer depends on future probe input.
///
/// Returns `Ok(None)` in these cases:
/// - nothing is being pruned;
/// - the join type never emits build-only rows for this side;
/// - the produced batch is empty.
///
/// Probe-side output columns are taken from an empty batch of
/// `probe_schema`, paired with null probe indices.
pub(crate) fn build_side_determined_results(
    build_hash_joiner: &OneSideHashJoiner,
    output_schema: &SchemaRef,
    prune_length: usize,
    probe_schema: SchemaRef,
    join_type: JoinType,
    column_indices: &[ColumnIndex],
) -> Result<Option<RecordBatch>> {
    let build_side = build_hash_joiner.build_side;
    if prune_length == 0 || !need_to_produce_result_in_final(build_side, join_type) {
        return Ok(None);
    }
    let (build_indices, probe_indices) = calculate_indices_by_join_type(
        build_side,
        prune_length,
        &build_hash_joiner.visited_rows,
        build_hash_joiner.deleted_offset,
        join_type,
    )?;
    let empty_probe_batch = RecordBatch::new_empty(probe_schema);
    let batch = build_batch_from_indices(
        output_schema.as_ref(),
        &build_hash_joiner.input_buffer,
        &empty_probe_batch,
        &build_indices,
        &probe_indices,
        column_indices,
        build_side,
    )?;
    Ok((batch.num_rows() > 0).then_some(batch))
}
/// Joins `probe_batch` against the rows currently buffered by `build_hash_joiner`.
///
/// Steps:
/// 1. Look up matching `(build, probe)` index pairs through the build side's
///    hash map (`lookup_join_hashmap`).
/// 2. Apply the join `filter`, if any.
/// 3. Record matched rows in each side's `visited_rows` when that side may
///    have to emit unmatched/semi results later.
/// 4. Build the output batch.
///
/// Returns `Ok(None)` in these cases:
/// - either input is empty;
/// - the join type is a semi or anti join (those results are only emitted
///   when rows are pruned or the inputs end);
/// - no rows were produced.
// One argument per piece of join state the lookup needs.
#[allow(clippy::too_many_arguments)]
pub(crate) fn join_with_probe_batch(
build_hash_joiner: &mut OneSideHashJoiner,
probe_hash_joiner: &mut OneSideHashJoiner,
schema: &SchemaRef,
join_type: JoinType,
filter: Option<&JoinFilter>,
probe_batch: &RecordBatch,
column_indices: &[ColumnIndex],
random_state: &RandomState,
null_equals_null: bool,
) -> Result<Option<RecordBatch>> {
if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
return Ok(None);
}
let (build_indices, probe_indices) = lookup_join_hashmap(
&build_hash_joiner.hashmap,
&build_hash_joiner.input_buffer,
probe_batch,
&build_hash_joiner.on,
&probe_hash_joiner.on,
random_state,
null_equals_null,
&mut build_hash_joiner.hashes_buffer,
Some(build_hash_joiner.deleted_offset),
)?;
let (build_indices, probe_indices) = if let Some(filter) = filter {
apply_join_filter_to_indices(
&build_hash_joiner.input_buffer,
probe_batch,
build_indices,
probe_indices,
filter,
build_hash_joiner.build_side,
)?
} else {
(build_indices, probe_indices)
};
// Build indices are relative to the pruned buffer; `deleted_offset` turns
// them into stream-wide row numbers for `visited_rows`.
if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
record_visited_indices(
&mut build_hash_joiner.visited_rows,
build_hash_joiner.deleted_offset,
&build_indices,
);
}
// Probe indices are relative to `probe_batch`. The caller advances
// `probe_hash_joiner.offset` only after this call (see
// `perform_join_for_given_side`), so `offset` is this batch's first row number.
if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
record_visited_indices(
&mut probe_hash_joiner.visited_rows,
probe_hash_joiner.offset,
&probe_indices,
);
}
// Semi/anti results depend on visit state and are emitted later.
if matches!(
join_type,
JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftSemi
| JoinType::RightSemi
) {
Ok(None)
} else {
build_batch_from_indices(
schema,
&build_hash_joiner.input_buffer,
probe_batch,
&build_indices,
&probe_indices,
column_indices,
build_hash_joiner.build_side,
)
.map(|batch| (batch.num_rows() > 0).then_some(batch))
}
}
/// Finds the `(build, probe)` row index pairs whose join keys match.
///
/// Probe keys are hashed into `hashes_buffer`, which is cleared and resized
/// on every call so its allocation can be reused. The hashes are looked up
/// in `build_hashmap`, with `deleted_offset` forwarded to
/// `get_matched_indices`. The candidate pairs are then passed to
/// `equal_rows_arr` together with the actual key values and
/// `null_equals_null`; presumably this discards hash collisions and applies
/// the null-equality rule — TODO confirm.
///
/// Returns build indices as `UInt64Array` and probe indices as `UInt32Array`.
// One argument per piece of join state the lookup needs.
#[allow(clippy::too_many_arguments)]
fn lookup_join_hashmap(
    build_hashmap: &PruningJoinHashMap,
    build_batch: &RecordBatch,
    probe_batch: &RecordBatch,
    build_on: &[PhysicalExprRef],
    probe_on: &[PhysicalExprRef],
    random_state: &RandomState,
    null_equals_null: bool,
    hashes_buffer: &mut Vec<u64>,
    deleted_offset: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
    // Evaluates each key expression against `batch` into a full-length array.
    let evaluate_keys = |exprs: &[PhysicalExprRef], batch: &RecordBatch| {
        exprs
            .iter()
            .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows()))
            .collect::<Result<Vec<_>>>()
    };
    let probe_key_arrays = evaluate_keys(probe_on, probe_batch)?;
    let build_key_arrays = evaluate_keys(build_on, build_batch)?;

    hashes_buffer.clear();
    hashes_buffer.resize(probe_batch.num_rows(), 0);
    let probe_hashes = create_hashes(&probe_key_arrays, random_state, hashes_buffer)?;

    // Probe rows are fed in reverse and the collected index buffers are
    // reversed afterwards; presumably this yields output in probe-row order
    // given how `get_matched_indices` walks its chains — TODO confirm.
    let (mut probe_matches, mut build_matches) = build_hashmap
        .get_matched_indices(probe_hashes.iter().enumerate().rev(), deleted_offset);
    probe_matches.as_slice_mut().reverse();
    build_matches.as_slice_mut().reverse();

    let build_candidates: UInt64Array =
        PrimitiveArray::new(build_matches.finish().into(), None);
    let probe_candidates: UInt32Array =
        PrimitiveArray::new(probe_matches.finish().into(), None);

    let (build_indices, probe_indices) = equal_rows_arr(
        &build_candidates,
        &probe_candidates,
        &build_key_arrays,
        &probe_key_arrays,
        null_equals_null,
    )?;
    Ok((build_indices, probe_indices))
}
/// Per-side state of the symmetric hash join.
///
/// Holds the rows received on one input that have not been pruned yet, plus a
/// hash table over their join keys.
pub struct OneSideHashJoiner {
/// Which input (left or right) this joiner buffers.
build_side: JoinSide,
/// Unpruned rows of this side. New batches are appended with
/// `concat_batches`; pruned rows are sliced off the front.
pub input_buffer: RecordBatch,
/// Join key expressions evaluated against this side's rows.
pub(crate) on: Vec<PhysicalExprRef>,
/// Hash table over the join keys of the buffered rows.
pub(crate) hashmap: PruningJoinHashMap,
/// Scratch buffer reused for computing key hashes.
pub(crate) hashes_buffer: Vec<u64>,
/// Stream-wide row numbers (pruned rows included) recorded as matched. This
/// is only maintained for join types that may emit this side's
/// unmatched/semi rows later.
pub(crate) visited_rows: HashSet<usize>,
/// Total number of rows inserted on this side so far. The caller advances
/// it after each batch is processed.
pub(crate) offset: usize,
/// Number of rows pruned from the front of `input_buffer` so far.
pub(crate) deleted_offset: usize,
}
impl OneSideHashJoiner {
    /// Returns an estimate, in bytes, of the memory held by this joiner.
    ///
    /// The estimate sums these terms:
    /// - the struct itself;
    /// - the buffered batch;
    /// - the hash map;
    /// - the capacities of the hash scratch buffer and the visited-row set.
    ///
    /// Some fixed-size fields are counted both inside `size_of_val(self)` and
    /// individually, so this is a slight overestimate.
    pub fn size(&self) -> usize {
        std::mem::size_of_val(self)
            + std::mem::size_of_val(&self.build_side)
            + self.input_buffer.get_array_memory_size()
            + std::mem::size_of_val(&self.on)
            + self.hashmap.size()
            + self.hashes_buffer.capacity() * std::mem::size_of::<u64>()
            + self.visited_rows.capacity() * std::mem::size_of::<usize>()
            + std::mem::size_of_val(&self.offset)
            + std::mem::size_of_val(&self.deleted_offset)
    }

    /// Creates an empty joiner for `build_side`.
    ///
    /// The joiner is keyed on `on` and buffers rows that conform to `schema`.
    pub fn new(
        build_side: JoinSide,
        on: Vec<PhysicalExprRef>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            build_side,
            on,
            input_buffer: RecordBatch::new_empty(schema),
            hashmap: PruningJoinHashMap::with_capacity(0),
            hashes_buffer: Vec::new(),
            visited_rows: HashSet::new(),
            offset: 0,
            deleted_offset: 0,
        }
    }

    /// Appends `batch` to the buffered rows and inserts its keys into the
    /// hash map.
    ///
    /// Rows are numbered starting at `self.offset`. This method does not
    /// advance `offset`.
    pub(crate) fn update_internal_state(
        &mut self,
        batch: &RecordBatch,
        random_state: &RandomState,
    ) -> Result<()> {
        let merged = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
        self.input_buffer = merged;
        self.hashes_buffer.resize(batch.num_rows(), 0);
        update_hash(
            &self.on,
            batch,
            &mut self.hashmap,
            self.offset,
            random_state,
            &mut self.hashes_buffer,
            self.deleted_offset,
            false,
        )?;
        Ok(())
    }

    /// Narrows the build side's filter interval using the interval graph and
    /// returns how many leading buffered rows can be pruned.
    ///
    /// Both sides' current filter intervals are fed to `graph.update_ranges`
    /// under the assumption that the filter evaluates to `CERTAINLY_TRUE`.
    /// Returns `0` in these cases:
    /// - nothing is buffered;
    /// - the build interval did not change.
    ///
    /// Otherwise the narrowed interval is stored on
    /// `build_side_sorted_filter_expr` and the prune length is computed from
    /// it.
    pub(crate) fn calculate_prune_length_with_probe_batch(
        &mut self,
        build_side_sorted_filter_expr: &mut SortedFilterExpr,
        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
        graph: &mut ExprIntervalGraph,
    ) -> Result<usize> {
        if self.input_buffer.num_rows() == 0 {
            return Ok(0);
        }
        let mut filter_intervals = vec![
            (
                build_side_sorted_filter_expr.node_index(),
                build_side_sorted_filter_expr.interval().clone(),
            ),
            (
                probe_side_sorted_filter_expr.node_index(),
                probe_side_sorted_filter_expr.interval().clone(),
            ),
        ];
        graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?;
        // Entry 0 is the build side, as pushed above.
        let (_, narrowed_build_interval) = filter_intervals.swap_remove(0);
        if narrowed_build_interval.eq(build_side_sorted_filter_expr.interval()) {
            return Ok(0);
        }
        build_side_sorted_filter_expr.set_interval(narrowed_build_interval);
        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
    }

    /// Drops the first `prune_length` buffered rows.
    ///
    /// The rows are removed from the hash map, from `visited_rows` and from
    /// `input_buffer`. `deleted_offset` is then advanced accordingly.
    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
        self.hashmap.prune_hash_values(
            prune_length,
            self.deleted_offset as u64,
            HASHMAP_SHRINK_SCALE_FACTOR,
        );
        let first_kept_row = self.deleted_offset + prune_length;
        for pruned_row in self.deleted_offset..first_kept_row {
            self.visited_rows.remove(&pruned_row);
        }
        let remaining_rows = self.input_buffer.num_rows() - prune_length;
        self.input_buffer = self.input_buffer.slice(prune_length, remaining_rows);
        self.deleted_offset = first_kept_row;
        Ok(())
    }
}
impl EagerJoinStream for SymmetricHashJoinStream {
    /// Joins a batch from the right input against the buffered left rows.
    ///
    /// Yields `Ready` when output was produced and `Continue` otherwise.
    fn process_batch_from_right(
        &mut self,
        batch: RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let maybe_batch = self.perform_join_for_given_side(batch, JoinSide::Right)?;
        Ok(if maybe_batch.is_none() {
            StatefulStreamResult::Continue
        } else {
            StatefulStreamResult::Ready(maybe_batch)
        })
    }

    /// Joins a batch from the left input against the buffered right rows.
    ///
    /// Yields `Ready` when output was produced and `Continue` otherwise.
    fn process_batch_from_left(
        &mut self,
        batch: RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let maybe_batch = self.perform_join_for_given_side(batch, JoinSide::Left)?;
        Ok(if maybe_batch.is_none() {
            StatefulStreamResult::Continue
        } else {
            StatefulStreamResult::Ready(maybe_batch)
        })
    }

    /// After the left input ends, right batches are processed as usual.
    fn process_batch_after_left_end(
        &mut self,
        right_batch: RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_right(right_batch)
    }

    /// After the right input ends, left batches are processed as usual.
    fn process_batch_after_right_end(
        &mut self,
        left_batch: RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_left(left_batch)
    }

    /// Emits the remaining build-side-only results of both sides.
    ///
    /// Every still-buffered row is treated as pruned: the prune length is the
    /// full buffer length of each side.
    fn process_batches_before_finalization(
        &mut self,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let left_result = build_side_determined_results(
            &self.left,
            &self.schema,
            self.left.input_buffer.num_rows(),
            self.right.input_buffer.schema(),
            self.join_type,
            &self.column_indices,
        )?;
        let right_result = build_side_determined_results(
            &self.right,
            &self.schema,
            self.right.input_buffer.num_rows(),
            self.left.input_buffer.schema(),
            self.join_type,
            &self.column_indices,
        )?;
        match combine_two_batches(&self.schema, left_result, right_result)? {
            Some(batch) => {
                self.metrics.output_batches.add(1);
                self.metrics.output_rows.add(batch.num_rows());
                Ok(StatefulStreamResult::Ready(Some(batch)))
            }
            None => Ok(StatefulStreamResult::Continue),
        }
    }

    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.right_stream
    }

    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.left_stream
    }

    fn set_state(&mut self, state: EagerJoinStreamState) {
        self.state = state;
    }

    fn state(&mut self) -> EagerJoinStreamState {
        self.state.clone()
    }
}
impl SymmetricHashJoinStream {
/// Returns an estimate, in bytes, of the memory held by this stream.
///
/// It sums both joiners' `size()`, the interval graph (if any) and the
/// inline sizes of the remaining fields. The input streams, `reservation`
/// and `state` are not counted.
fn size(&self) -> usize {
let mut size = 0;
size += std::mem::size_of_val(&self.schema);
size += std::mem::size_of_val(&self.filter);
size += std::mem::size_of_val(&self.join_type);
size += self.left.size();
size += self.right.size();
size += std::mem::size_of_val(&self.column_indices);
size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
size += std::mem::size_of_val(&self.left_sorted_filter_expr);
size += std::mem::size_of_val(&self.right_sorted_filter_expr);
size += std::mem::size_of_val(&self.random_state);
size += std::mem::size_of_val(&self.null_equals_null);
size += std::mem::size_of_val(&self.metrics);
size
}
/// Processes one batch arriving on `probe_side`.
///
/// Steps:
/// 1. Insert the batch into the probe side's own buffer and hash table.
/// 2. Join it against the other (build) side's buffered rows.
/// 3. When pruning is configured (both sorted filter expressions and the
///    interval graph exist), narrow the build side's filter interval, emit
///    build-side-only results for the prunable rows, and prune them.
/// 4. Resize the memory reservation to the new `size()` and update the
///    output metrics.
///
/// # Errors
///
/// Propagates errors from hashing, joining and interval computation. Also
/// returns an error if the memory reservation cannot be resized.
fn perform_join_for_given_side(
&mut self,
probe_batch: RecordBatch,
probe_side: JoinSide,
) -> Result<Option<RecordBatch>> {
// Borrow disjoint fields so that the "probe" and "build" roles can be
// assigned to left/right depending on where the batch came from.
let (
probe_hash_joiner,
build_hash_joiner,
probe_side_sorted_filter_expr,
build_side_sorted_filter_expr,
probe_side_metrics,
) = if probe_side.eq(&JoinSide::Left) {
(
&mut self.left,
&mut self.right,
&mut self.left_sorted_filter_expr,
&mut self.right_sorted_filter_expr,
&mut self.metrics.left,
)
} else {
(
&mut self.right,
&mut self.left,
&mut self.right_sorted_filter_expr,
&mut self.left_sorted_filter_expr,
&mut self.metrics.right,
)
};
probe_side_metrics.input_batches.add(1);
probe_side_metrics.input_rows.add(probe_batch.num_rows());
probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?;
let equal_result = join_with_probe_batch(
build_hash_joiner,
probe_hash_joiner,
&self.schema,
self.join_type,
self.filter.as_ref(),
&probe_batch,
&self.column_indices,
&self.random_state,
self.null_equals_null,
)?;
// Advanced only after the join, which uses the pre-batch `offset` when
// recording visited probe rows.
probe_hash_joiner.offset += probe_batch.num_rows();
// Pruning is only possible when both sorted filter expressions and
// the interval graph are available.
let anti_result = if let (
Some(build_side_sorted_filter_expr),
Some(probe_side_sorted_filter_expr),
Some(graph),
) = (
build_side_sorted_filter_expr.as_mut(),
probe_side_sorted_filter_expr.as_mut(),
self.graph.as_mut(),
) {
calculate_filter_expr_intervals(
&build_hash_joiner.input_buffer,
build_side_sorted_filter_expr,
&probe_batch,
probe_side_sorted_filter_expr,
)?;
let prune_length = build_hash_joiner
.calculate_prune_length_with_probe_batch(
build_side_sorted_filter_expr,
probe_side_sorted_filter_expr,
graph,
)?;
// Results must be built before pruning removes the rows they refer to.
let result = build_side_determined_results(
build_hash_joiner,
&self.schema,
prune_length,
probe_batch.schema(),
self.join_type,
&self.column_indices,
)?;
build_hash_joiner.prune_internal_state(prune_length)?;
result
} else {
None
};
let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
let capacity = self.size();
self.metrics.stream_memory_usage.set(capacity);
self.reservation.lock().try_resize(capacity)?;
if let Some(batch) = &result {
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Mutex;
use super::*;
use crate::joins::test_utils::{
build_sides_record_batches, compare_batches, complicated_filter,
create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
partitioned_sym_join_with_filter, split_record_batches,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion_common::ScalarValue;
use datafusion_execution::config::SessionConfig;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, col, lit, Column};
use once_cell::sync::Lazy;
use rstest::*;
/// Size argument passed to `build_sides_record_batches` when generating test
/// tables (presumably the number of rows per side — TODO confirm).
const TABLE_SIZE: i32 = 30;

/// Cache key: `(left cardinality, right cardinality, batch size)`.
type TableKey = (i32, i32, usize);

/// Cached `(left_partition, right_partition)` batches for one key.
type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>);

/// Process-wide cache of generated test tables.
///
/// Tests with identical parameters reuse one fixture instead of regenerating
/// it; the `Mutex` makes the cache safe to share across test threads.
static TABLE_CACHE: Lazy<Mutex<HashMap<TableKey, TableValue>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
fn get_or_create_table(
cardinality: (i32, i32),
batch_size: usize,
) -> Result<TableValue> {
{
let cache = TABLE_CACHE.lock().unwrap();
if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
return Ok(table.clone());
}
}
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let (left_partition, right_partition) = (
split_record_batches(&left_batch, batch_size)?,
split_record_batches(&right_batch, batch_size)?,
);
let mut cache = TABLE_CACHE.lock().unwrap();
cache.insert(
(cardinality.0, cardinality.1, batch_size),
(left_partition.clone(), right_partition.clone()),
);
Ok((left_partition, right_partition))
}
/// Runs the same join twice, once with the partitioned symmetric hash join
/// and once with the partitioned hash join, and compares the outputs with
/// `compare_batches`.
pub async fn experiment(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    filter: Option<JoinFilter>,
    join_type: JoinType,
    on: JoinOn,
    task_ctx: Arc<TaskContext>,
) -> Result<()> {
    let symmetric_join_batches = partitioned_sym_join_with_filter(
        left.clone(),
        right.clone(),
        on.clone(),
        filter.clone(),
        &join_type,
        false,
        task_ctx.clone(),
    )
    .await?;
    let hash_join_batches = partitioned_hash_join_with_filter(
        left, right, on, filter, &join_type, false, task_ctx,
    )
    .await?;
    compare_batches(&symmetric_join_batches, &hash_join_batches);
    Ok(())
}
/// Symmetric vs. regular hash join for every join type and two cardinalities.
///
/// The left side is sorted ascending on `la1 + la2` and the right side on
/// `ra1`. The equi-join key is `lc1 + 1 = rc1`, and the filter is
/// `complicated_filter` over three intermediate columns.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_numeric(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(12, 17),
)]
cardinality: (i32, i32),
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let left_sorted = vec![PhysicalSortExpr {
expr: binary(
col("la1", left_schema)?,
Operator::Plus,
col("la2", left_schema)?,
left_schema,
)?,
options: SortOptions::default(),
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", right_schema)?,
options: SortOptions::default(),
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
vec![left_sorted],
vec![right_sorted],
)?;
let on = vec![(
binary(
col("lc1", left_schema)?,
Operator::Plus,
lit(ScalarValue::Int32(Some(1))),
left_schema,
)?,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
Field::new("1", DataType::Int32, true),
Field::new("2", DataType::Int32, true),
]);
let filter_expr = complicated_filter(&intermediate_schema)?;
// Filter columns map to left column 0, left column 4 and right column 0
// (presumably la1, la2 and ra1 — TODO confirm against the fixture schema).
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 4,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
/// Symmetric vs. regular hash join for every join type and six filter cases.
///
/// Both sides are sorted ascending on their first column (`la1` / `ra1`).
/// The equi-join key is `lc1 = rc1`, and the filter comes from
/// `join_expr_tests_fixture_i32`.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_all_one_ascending_numeric(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1", left_schema)?,
options: SortOptions::default(),
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", right_schema)?,
options: SortOptions::default(),
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
vec![left_sorted],
vec![right_sorted],
)?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
// Filter inputs are column 0 of each side (presumably the sorted la1 / ra1).
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
/// Symmetric vs. regular hash join when neither input declares a sort order.
///
/// The filter is still present, but no pruning can be configured. This
/// exercises the path where the sorted filter expressions and the interval
/// graph are all `None`.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_without_sort_information(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let (left, right) =
create_memory_table(left_partition, right_partition, vec![], vec![])?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
// Filter inputs are column 5 of each side (which fixture columns these are
// is defined by `build_sides_record_batches` — TODO confirm).
let column_indices = vec![
ColumnIndex {
index: 5,
side: JoinSide::Left,
},
ColumnIndex {
index: 5,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
/// Symmetric vs. regular hash join for every join type with only the
/// equi-join key `lc1 = rc1`: no filter and no sort information.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_without_filter(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let (left, right) =
create_memory_table(left_partition, right_partition, vec![], vec![])?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
experiment(left, right, None, join_type, on, task_ctx).await?;
Ok(())
}
/// Symmetric vs. regular hash join with descending orderings.
///
/// Both sides are sorted descending with nulls first (`la1_des` / `ra1_des`).
/// This exercises pruning against the interval's upper bound, for every join
/// type and six filter cases.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn join_all_one_descending_numeric_particular(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("la1_des", left_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1_des", right_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
vec![left_sorted],
vec![right_sorted],
)?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
// Filter inputs are column 5 of each side (presumably la1_des / ra1_des —
// TODO confirm against the fixture schema).
let column_indices = vec![
ColumnIndex {
index: 5,
side: JoinSide::Left,
},
ColumnIndex {
index: 5,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_first() -> Result<()> {
let join_type = JoinType::Full;
let case_expr = 1;
let session_config = SessionConfig::new().with_repartition_joins(false);
let task_ctx = TaskContext::default().with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);
let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("l_asc_null_first", left_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("r_asc_null_first", right_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
vec![left_sorted],
vec![right_sorted],
)?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 6,
side: JoinSide::Left,
},
ColumnIndex {
index: 6,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
/// Full join with i32 filter case 1.
///
/// Both inputs are declared sorted ascending with nulls last
/// (`l_asc_null_last` / `r_asc_null_last`). Join repartitioning is disabled
/// in the session config.
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_last() -> Result<()> {
    let filter_case = 1;
    let config = SessionConfig::new().with_repartition_joins(false);
    let task_ctx = Arc::new(TaskContext::default().with_session_config(config));

    let (left_batches, right_batches) = get_or_create_table((10, 11), 8)?;
    let left_schema = left_batches[0].schema();
    let right_schema = right_batches[0].schema();

    let ascending_nulls_last = SortOptions {
        descending: false,
        nulls_first: false,
    };
    let (left, right) = create_memory_table(
        left_batches,
        right_batches,
        vec![vec![PhysicalSortExpr {
            expr: col("l_asc_null_last", &left_schema)?,
            options: ascending_nulls_last,
        }]],
        vec![vec![PhysicalSortExpr {
            expr: col("r_asc_null_last", &right_schema)?,
            options: ascending_nulls_last,
        }]],
    )?;
    let on = vec![(
        Arc::new(Column::new_with_schema("lc1", &left_schema)?) as _,
        Arc::new(Column::new_with_schema("rc1", &right_schema)?) as _,
    )];

    // The filter reads one nullable Int32 column per side, at index 7 of each input.
    let filter_schema = Schema::new(vec![
        Field::new("left", DataType::Int32, true),
        Field::new("right", DataType::Int32, true),
    ]);
    let filter_expr = join_expr_tests_fixture_i32(
        filter_case,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
    );
    let column_indices = [JoinSide::Left, JoinSide::Right]
        .into_iter()
        .map(|side| ColumnIndex { index: 7, side })
        .collect::<Vec<_>>();
    let filter = JoinFilter::new(filter_expr, column_indices, filter_schema);

    experiment(left, right, Some(filter), JoinType::Full, on, task_ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn build_null_columns_first_descending() -> Result<()> {
let join_type = JoinType::Full;
let cardinality = (10, 11);
let case_expr = 1;
let session_config = SessionConfig::new().with_repartition_joins(false);
let task_ctx = TaskContext::default().with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);
let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let left_sorted = vec![PhysicalSortExpr {
expr: col("l_desc_null_first", left_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("r_desc_null_first", right_schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
vec![left_sorted],
vec![right_sorted],
)?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
);
let column_indices = vec![
ColumnIndex {
index: 8,
side: JoinSide::Left,
},
ColumnIndex {
index: 8,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
/// Full join using the three-column `complicated_filter`.
///
/// Both inputs are declared sorted on `la1` / `ra1` with default sort
/// options. Join repartitioning is disabled in the session config.
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config));

    let (left_batches, right_batches) = get_or_create_table((3, 4), 8)?;
    let left_schema = left_batches[0].schema();
    let right_schema = right_batches[0].schema();

    let on = vec![(
        Arc::new(Column::new_with_schema("lc1", &left_schema)?) as _,
        Arc::new(Column::new_with_schema("rc1", &right_schema)?) as _,
    )];
    let left_order = vec![PhysicalSortExpr {
        expr: col("la1", &left_schema)?,
        options: SortOptions::default(),
    }];
    let right_order = vec![PhysicalSortExpr {
        expr: col("ra1", &right_schema)?,
        options: SortOptions::default(),
    }];
    let (left, right) = create_memory_table(
        left_batches,
        right_batches,
        vec![left_order],
        vec![right_order],
    )?;

    // Filter columns "0", "1" and "2" (nullable Int32) are fed from left[0],
    // left[4] and right[0], in that order.
    let filter_schema = Schema::new(vec![
        Field::new("0", DataType::Int32, true),
        Field::new("1", DataType::Int32, true),
        Field::new("2", DataType::Int32, true),
    ]);
    let filter_expr = complicated_filter(&filter_schema)?;
    let column_indices = [
        (0, JoinSide::Left),
        (4, JoinSide::Left),
        (0, JoinSide::Right),
    ]
    .into_iter()
    .map(|(index, side)| ColumnIndex { index, side })
    .collect::<Vec<_>>();
    let filter = JoinFilter::new(filter_expr, column_indices, filter_schema);

    experiment(left, right, Some(filter), JoinType::Full, on, task_ctx).await?;
    Ok(())
}
#[tokio::test(flavor = "multi_thread")]
async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
let cardinality = (3, 4);
let join_type = JoinType::Full;
let config = SessionConfig::new().with_repartition_joins(false);
let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let left_sorted = vec![
vec![PhysicalSortExpr {
expr: col("la1", left_schema)?,
options: SortOptions::default(),
}],
vec![PhysicalSortExpr {
expr: col("la2", left_schema)?,
options: SortOptions::default(),
}],
];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ra1", right_schema)?,
options: SortOptions::default(),
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
left_sorted,
vec![right_sorted],
)?;
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
Field::new("1", DataType::Int32, true),
Field::new("2", DataType::Int32, true),
]);
let filter_expr = complicated_filter(&intermediate_schema)?;
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 4,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn testing_with_temporal_columns(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(12, 17),
)]
cardinality: (i32, i32),
#[values(0, 1, 2)] case_expr: usize,
) -> Result<()> {
let session_config = SessionConfig::new().with_repartition_joins(false);
let task_ctx = TaskContext::default().with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);
let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let on = vec![(
Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("lt1", left_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("rt1", right_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let (left, right) = create_memory_table(
left_partition,
right_partition,
vec![left_sorted],
vec![right_sorted],
)?;
let intermediate_schema = Schema::new(vec![
Field::new(
"left",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"right",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
]);
let filter_expr = join_expr_tests_fixture_temporal(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
&intermediate_schema,
)?;
let column_indices = vec![
ColumnIndex {
index: 3,
side: JoinSide::Left,
},
ColumnIndex {
index: 3,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
/// Every join type × two table cardinalities, using temporal filter case 0.
///
/// Both inputs are declared sorted on `li1` / `ri1` (ascending, nulls first).
/// The filter operates on non-nullable `Interval(DayTime)` columns. Join
/// repartitioning is disabled in the session config.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn test_with_interval_columns(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::RightAnti,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(
        (4, 5),
        (12, 17),
    )]
    cardinality: (i32, i32),
) -> Result<()> {
    let task_ctx = Arc::new(TaskContext::default().with_session_config(
        SessionConfig::new().with_repartition_joins(false),
    ));

    let (left_batches, right_batches) = get_or_create_table(cardinality, 8)?;
    let left_schema = left_batches[0].schema();
    let right_schema = right_batches[0].schema();
    let on = vec![(
        Arc::new(Column::new_with_schema("lc1", &left_schema)?) as _,
        Arc::new(Column::new_with_schema("rc1", &right_schema)?) as _,
    )];

    let ascending_nulls_first = SortOptions {
        descending: false,
        nulls_first: true,
    };
    let (left, right) = create_memory_table(
        left_batches,
        right_batches,
        vec![vec![PhysicalSortExpr {
            expr: col("li1", &left_schema)?,
            options: ascending_nulls_first,
        }]],
        vec![vec![PhysicalSortExpr {
            expr: col("ri1", &right_schema)?,
            options: ascending_nulls_first,
        }]],
    )?;

    let day_time = DataType::Interval(IntervalUnit::DayTime);
    let filter_schema = Schema::new(vec![
        Field::new("left", day_time.clone(), false),
        Field::new("right", day_time, false),
    ]);
    let filter_expr = join_expr_tests_fixture_temporal(
        0,
        col("left", &filter_schema)?,
        col("right", &filter_schema)?,
        &filter_schema,
    )?;
    // Both filter columns come from index 9 of their respective inputs.
    let column_indices = [JoinSide::Left, JoinSide::Right]
        .into_iter()
        .map(|side| ColumnIndex { index: 9, side })
        .collect::<Vec<_>>();
    let filter = JoinFilter::new(filter_expr, column_indices, filter_schema);

    experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
    Ok(())
}
/// Every join type × two table cardinalities × f64 filter cases 0–5.
///
/// Both inputs are declared sorted on `l_float` / `r_float` with default sort
/// options, and rows are equi-joined on `lc1 = rc1`. Join repartitioning is
/// disabled in the session config.
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn testing_ascending_float_pruning(
    #[values(
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::RightSemi,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::RightAnti,
        JoinType::Full
    )]
    join_type: JoinType,
    #[values(
        (4, 5),
        (12, 17),
    )]
    cardinality: (i32, i32),
    #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
) -> Result<()> {
    let session_config = SessionConfig::new().with_repartition_joins(false);
    let task_ctx = TaskContext::default().with_session_config(session_config);
    let task_ctx = Arc::new(task_ctx);
    let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
    let left_schema = &left_partition[0].schema();
    let right_schema = &right_partition[0].schema();
    let left_sorted = vec![PhysicalSortExpr {
        expr: col("l_float", left_schema)?,
        options: SortOptions::default(),
    }];
    let right_sorted = vec![PhysicalSortExpr {
        expr: col("r_float", right_schema)?,
        options: SortOptions::default(),
    }];
    let (left, right) = create_memory_table(
        left_partition,
        right_partition,
        vec![left_sorted],
        vec![right_sorted],
    )?;
    let on = vec![(
        Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
        Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
    )];
    let intermediate_schema = Schema::new(vec![
        Field::new("left", DataType::Float64, true),
        Field::new("right", DataType::Float64, true),
    ]);
    let filter_expr = join_expr_tests_fixture_f64(
        case_expr,
        col("left", &intermediate_schema)?,
        col("right", &intermediate_schema)?,
    );
    // The filter reads column index 10 from each side.
    // NOTE(review): this is assumed to be the sorted float column
    // (`l_float` / `r_float`); confirm against `get_or_create_table`.
    let column_indices = vec![
        ColumnIndex {
            index: 10,
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: 10,
            side: JoinSide::Right,
        },
    ];
    let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
    experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
    Ok(())
}
}