use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::Poll;
use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
use crate::joins::Map;
use crate::joins::MapOffset;
use crate::joins::PartitionMode;
use crate::joins::hash_join::exec::JoinLeftData;
use crate::joins::hash_join::shared_bounds::{
PartitionBounds, PartitionBuildData, SharedBuildAccumulator,
};
use crate::joins::utils::{
OnceFut, equal_rows_arr, get_final_indices_from_shared_bitmap,
};
use crate::{
RecordBatchStream, SendableRecordBatchStream, handle_state,
hash_utils::create_hashes,
joins::utils::{
BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType,
StatefulStreamResult, adjust_indices_by_join_type, apply_join_filter_to_indices,
build_batch_empty_build_side, build_batch_from_indices,
need_produce_result_in_final,
},
};
use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{
JoinSide, JoinType, NullEquality, Result, internal_datafusion_err, internal_err,
};
use datafusion_physical_expr::PhysicalExprRef;
use ahash::RandomState;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use futures::{Stream, StreamExt, ready};
/// Tracks the build (left) side of the join as it transitions from
/// "still being collected" to "ready for probing".
pub(super) enum BuildSide {
    /// Build side not collected yet; holds the future that resolves to it.
    Initial(BuildSideInitialState),
    /// Build side fully collected and available for probe lookups.
    Ready(BuildSideReadyState),
}
/// Payload of [`BuildSide::Initial`]: the in-flight collection of the
/// build side.
pub(super) struct BuildSideInitialState {
    /// Shared future resolving to the collected build-side data.
    pub(super) left_fut: OnceFut<JoinLeftData>,
}
/// Payload of [`BuildSide::Ready`]: the fully collected build-side data.
pub(super) struct BuildSideReadyState {
    /// Collected build-side data (lookup map, batch, visited bitmap, ...).
    left_data: Arc<JoinLeftData>,
}
impl BuildSide {
    /// Returns a mutable reference to the initial-state payload, or an
    /// internal error if the build side has already become ready.
    fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
        if let BuildSide::Initial(state) = self {
            Ok(state)
        } else {
            internal_err!("Expected build side in initial state")
        }
    }

    /// Returns a shared reference to the ready-state payload, or an
    /// internal error if the build side is still being collected.
    fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
        if let BuildSide::Ready(state) = self {
            Ok(state)
        } else {
            internal_err!("Expected build side in ready state")
        }
    }

    /// Mutable variant of [`Self::try_as_ready`].
    fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
        if let BuildSide::Ready(state) = self {
            Ok(state)
        } else {
            internal_err!("Expected build side in ready state")
        }
    }
}
/// State machine driving [`HashJoinStream`]; transitions happen inside
/// the handlers called from `poll_next_impl`.
#[derive(Debug, Clone)]
pub(super) enum HashJoinStreamState {
    /// Waiting for the build side to finish collecting.
    WaitBuildSide,
    /// Waiting for the shared build-data report future to complete.
    WaitPartitionBoundsReport,
    /// Ready to pull the next batch from the probe (right) input.
    FetchProbeBatch,
    /// Joining the current probe batch against the build side.
    ProcessProbeBatch(ProcessProbeBatchState),
    /// Probe input exhausted; emit unmatched build-side rows if the join
    /// type requires them.
    ExhaustedProbeSide,
    /// Stream finished; only buffered output remains to drain.
    Completed,
}
impl HashJoinStreamState {
    /// Returns the probe-batch payload, or an internal error when the
    /// stream is in any other state.
    fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> {
        if let HashJoinStreamState::ProcessProbeBatch(state) = self {
            Ok(state)
        } else {
            internal_err!("Expected hash join stream in ProcessProbeBatch state")
        }
    }
}
/// Per-batch state kept while a single probe batch is joined, possibly
/// across several limit-bounded lookup iterations.
#[derive(Debug, Clone)]
pub(super) struct ProcessProbeBatchState {
    /// The probe-side batch currently being processed.
    batch: RecordBatch,
    /// Join-key columns evaluated from `batch`.
    values: Vec<ArrayRef>,
    /// Resume point inside the build-side lookup for this batch.
    offset: MapOffset,
    /// Last probe-row index that produced a join match (used for index
    /// alignment); `None` until the first match.
    joined_probe_idx: Option<usize>,
}
impl ProcessProbeBatchState {
    /// Records the next lookup offset, and updates the last joined probe
    /// index only when a new one is supplied (a `None` leaves the
    /// previously recorded value intact).
    fn advance(&mut self, offset: MapOffset, joined_probe_idx: Option<usize>) {
        self.offset = offset;
        if let Some(idx) = joined_probe_idx {
            self.joined_probe_idx = Some(idx);
        }
    }
}
/// Streaming implementation of the probe phase of a hash join.
pub(super) struct HashJoinStream {
    /// Partition this stream executes for.
    partition: usize,
    /// Output schema of the join.
    schema: Arc<Schema>,
    /// Join-key expressions evaluated against the probe (right) input.
    on_right: Vec<PhysicalExprRef>,
    /// Optional residual filter applied to matched index pairs.
    filter: Option<JoinFilter>,
    /// The type of join being performed.
    join_type: JoinType,
    /// Probe-side input stream.
    right: SendableRecordBatchStream,
    /// Random state used to hash probe join keys.
    random_state: RandomState,
    /// Build/probe join metrics.
    join_metrics: BuildProbeJoinMetrics,
    /// Mapping from output columns to (side, input index) pairs.
    column_indices: Vec<ColumnIndex>,
    /// Whether NULL join keys compare equal.
    null_equality: NullEquality,
    /// Current state of the stream's state machine.
    state: HashJoinStreamState,
    /// Build (left) side data, either still collecting or ready.
    build_side: BuildSide,
    /// Target number of rows per output batch.
    batch_size: usize,
    /// Reused buffer for probe-key hashes.
    hashes_buffer: Vec<u64>,
    /// Reused scratch buffer for matched probe-side indices.
    probe_indices_buffer: Vec<u32>,
    /// Reused scratch buffer for matched build-side indices.
    build_indices_buffer: Vec<u64>,
    /// Whether the probe input is ordered (affects index adjustment).
    right_side_ordered: bool,
    /// Shared accumulator that build data is reported to, if configured.
    build_accumulator: Option<Arc<SharedBuildAccumulator>>,
    /// Future completing once the build-data report has been delivered.
    build_waiter: Option<OnceFut<()>>,
    /// Partitioning mode (`Partitioned` / `CollectLeft`).
    mode: PartitionMode,
    /// Coalesces output into `batch_size` batches and enforces `fetch`.
    output_buffer: LimitedBatchCoalescer,
    /// Whether null-aware join semantics apply (alters LeftAnti behavior
    /// when NULL keys appear — see `process_unmatched_build_batch`).
    null_aware: bool,
}
impl RecordBatchStream for HashJoinStream {
    /// Returns the output schema of the join stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
/// Looks up candidate build-side matches for the hashed probe keys, then
/// verifies the candidates column-by-column to eliminate hash collisions.
///
/// The two index buffers are caller-owned scratch space: they are moved out
/// to build the unfiltered candidate arrays, and afterwards refilled from
/// those arrays' backing storage so the allocations can typically be reused
/// across calls.
///
/// Returns verified `(build, probe)` index arrays plus the offset to resume
/// from when `limit` cut the scan short (`None` once this probe batch is
/// fully processed).
#[expect(clippy::too_many_arguments)]
pub(super) fn lookup_join_hashmap(
    build_hashmap: &dyn JoinHashMapType,
    build_side_values: &[ArrayRef],
    probe_side_values: &[ArrayRef],
    null_equality: NullEquality,
    hashes_buffer: &[u64],
    limit: usize,
    offset: MapOffset,
    probe_indices_buffer: &mut Vec<u32>,
    build_indices_buffer: &mut Vec<u64>,
) -> Result<(UInt64Array, UInt32Array, Option<MapOffset>)> {
    // Collect up to `limit` candidate pairs, resuming from `offset`.
    let next_offset = build_hashmap.get_matched_indices_with_limit_offset(
        hashes_buffer,
        limit,
        offset,
        probe_indices_buffer,
        build_indices_buffer,
    );
    // Move the scratch buffers into arrays for verification.
    let build_indices_unfiltered: UInt64Array =
        std::mem::take(build_indices_buffer).into();
    let probe_indices_unfiltered: UInt32Array =
        std::mem::take(probe_indices_buffer).into();
    // Keep only candidates whose key columns are actually equal (hashes can
    // collide), honoring the configured NULL-equality semantics.
    let (build_indices, probe_indices) = equal_rows_arr(
        &build_indices_unfiltered,
        &probe_indices_unfiltered,
        build_side_values,
        probe_side_values,
        null_equality,
    )?;
    // Hand the arrays' backing storage back to the scratch buffers.
    *build_indices_buffer = build_indices_unfiltered.into_parts().1.into();
    *probe_indices_buffer = probe_indices_unfiltered.into_parts().1.into();
    Ok((build_indices, probe_indices, next_offset))
}
/// Counts the distinct values in a sorted, null-free index array.
///
/// Equivalent to counting maximal runs of equal values: one distinct value
/// per run, with runs separated by adjacent unequal pairs.
#[inline]
fn count_distinct_sorted_indices(indices: &UInt32Array) -> usize {
    debug_assert!(indices.null_count() == 0);
    let buffer = indices.values();
    let values: &[u32] = buffer.as_ref();
    if values.is_empty() {
        return 0;
    }
    // One for the first run, plus one for every boundary between runs.
    1 + values
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count()
}
impl HashJoinStream {
    /// Creates a new [`HashJoinStream`].
    ///
    /// `fetch` is an optional cap on total output rows, enforced by the
    /// internal [`LimitedBatchCoalescer`].
    #[expect(clippy::too_many_arguments)]
    pub(super) fn new(
        partition: usize,
        schema: Arc<Schema>,
        on_right: Vec<PhysicalExprRef>,
        filter: Option<JoinFilter>,
        join_type: JoinType,
        right: SendableRecordBatchStream,
        random_state: RandomState,
        join_metrics: BuildProbeJoinMetrics,
        column_indices: Vec<ColumnIndex>,
        null_equality: NullEquality,
        state: HashJoinStreamState,
        build_side: BuildSide,
        batch_size: usize,
        hashes_buffer: Vec<u64>,
        right_side_ordered: bool,
        build_accumulator: Option<Arc<SharedBuildAccumulator>>,
        mode: PartitionMode,
        null_aware: bool,
        fetch: Option<usize>,
    ) -> Self {
        // Coalesces output into `batch_size`-row batches and applies `fetch`.
        let output_buffer =
            LimitedBatchCoalescer::new(Arc::clone(&schema), batch_size, fetch);
        Self {
            partition,
            schema,
            on_right,
            filter,
            join_type,
            right,
            random_state,
            join_metrics,
            column_indices,
            null_equality,
            state,
            build_side,
            batch_size,
            hashes_buffer,
            // Scratch buffers pre-sized to the output batch size for reuse.
            probe_indices_buffer: Vec::with_capacity(batch_size),
            build_indices_buffer: Vec::with_capacity(batch_size),
            right_side_ordered,
            build_accumulator,
            build_waiter: None,
            mode,
            output_buffer,
            null_aware,
        }
    }
    /// Drives the join state machine until an output batch is ready, the
    /// stream completes, or an inner await returns `Poll::Pending`.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            // Always drain completed output first so emitted batches respect
            // `batch_size` regardless of which state produced the rows.
            if let Some(batch) = self.output_buffer.next_completed_batch() {
                return self
                    .join_metrics
                    .baseline
                    .record_poll(Poll::Ready(Some(Ok(batch))));
            }
            if self.output_buffer.is_finished() {
                return Poll::Ready(None);
            }
            // Dispatch to the current state's handler; `handle_state!`
            // presumably re-enters the loop on `Continue` results.
            return match self.state {
                HashJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                HashJoinStreamState::WaitPartitionBoundsReport => {
                    handle_state!(ready!(self.wait_for_partition_bounds_report(cx)))
                }
                HashJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                HashJoinStreamState::ProcessProbeBatch(_) => {
                    handle_state!(self.process_probe_batch())
                }
                HashJoinStreamState::ExhaustedProbeSide => {
                    handle_state!(self.process_unmatched_build_batch())
                }
                // Completed but rows are still buffered: flush and loop so
                // the drain branch above can emit them.
                HashJoinStreamState::Completed if !self.output_buffer.is_empty() => {
                    self.output_buffer.finish()?;
                    continue;
                }
                HashJoinStreamState::Completed => Poll::Ready(None),
            };
        }
    }
fn wait_for_partition_bounds_report(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
if let Some(ref mut fut) = self.build_waiter {
ready!(fut.get_shared(cx))?;
}
self.state = HashJoinStreamState::FetchProbeBatch;
Poll::Ready(Ok(StatefulStreamResult::Continue))
}
fn collect_build_side(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let left_data = ready!(
self.build_side
.try_as_initial_mut()?
.left_fut
.get_shared(cx)
)?;
build_timer.done();
if let Some(ref build_accumulator) = self.build_accumulator {
let build_accumulator = Arc::clone(build_accumulator);
let left_side_partition_id = match self.mode {
PartitionMode::Partitioned => self.partition,
PartitionMode::CollectLeft => 0,
PartitionMode::Auto => unreachable!(
"PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"
),
};
let pushdown = left_data.membership().clone();
let build_data = match self.mode {
PartitionMode::Partitioned => PartitionBuildData::Partitioned {
partition_id: left_side_partition_id,
pushdown,
bounds: left_data
.bounds
.clone()
.unwrap_or_else(|| PartitionBounds::new(vec![])),
},
PartitionMode::CollectLeft => PartitionBuildData::CollectLeft {
pushdown,
bounds: left_data
.bounds
.clone()
.unwrap_or_else(|| PartitionBounds::new(vec![])),
},
PartitionMode::Auto => unreachable!(
"PartitionMode::Auto should not be present at execution time"
),
};
self.build_waiter = Some(OnceFut::new(async move {
build_accumulator.report_build_data(build_data).await
}));
self.state = HashJoinStreamState::WaitPartitionBoundsReport;
} else {
self.state = HashJoinStreamState::FetchProbeBatch;
}
self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
Poll::Ready(Ok(StatefulStreamResult::Continue))
}
fn fetch_probe_batch(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
match ready!(self.right.poll_next_unpin(cx)) {
None => {
self.state = HashJoinStreamState::ExhaustedProbeSide;
}
Some(Ok(batch)) => {
let keys_values = evaluate_expressions_to_arrays(&self.on_right, &batch)?;
if let Map::HashMap(_) = self.build_side.try_as_ready()?.left_data.map() {
self.hashes_buffer.clear();
self.hashes_buffer.resize(batch.num_rows(), 0);
create_hashes(
&keys_values,
&self.random_state,
&mut self.hashes_buffer,
)?;
}
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
self.state =
HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
batch,
values: keys_values,
offset: (0, None),
joined_probe_idx: None,
});
}
Some(Err(err)) => return Poll::Ready(Err(err)),
};
Poll::Ready(Ok(StatefulStreamResult::Continue))
}
    /// Joins the current probe batch against the build side, producing at
    /// most `batch_size` matched rows per call; when the lookup is cut short
    /// by that limit, `ProcessProbeBatchState` records where to resume.
    fn process_probe_batch(
        &mut self,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let state = self.state.try_as_process_probe_batch_mut()?;
        let build_side = self.build_side.try_as_ready_mut()?;
        self.join_metrics
            .probe_hit_rate
            .add_total(state.batch.num_rows());
        let timer = self.join_metrics.join_time.timer();
        if self.null_aware {
            // Record probe-side facts consumed later by the null-aware logic
            // in `process_unmatched_build_batch`.
            if state.batch.num_rows() > 0 {
                build_side
                    .left_data
                    .probe_side_non_empty
                    .store(true, Ordering::Relaxed);
            }
            let probe_key_column = &state.values[0];
            if probe_key_column.null_count() > 0 {
                build_side
                    .left_data
                    .probe_side_has_null
                    .store(true, Ordering::Relaxed);
            }
            // Once any probe key has been NULL, this null-aware join skips
            // further matching and just advances to the next batch.
            if build_side
                .left_data
                .probe_side_has_null
                .load(Ordering::Relaxed)
            {
                timer.done();
                self.state = HashJoinStreamState::FetchProbeBatch;
                return Ok(StatefulStreamResult::Continue);
            }
        }
        // Fast path: an empty build side with no residual filter needs no
        // index lookup at all.
        let is_empty = build_side.left_data.map().is_empty();
        if is_empty && self.filter.is_none() {
            let result = build_batch_empty_build_side(
                &self.schema,
                build_side.left_data.batch(),
                &state.batch,
                &self.column_indices,
                self.join_type,
            )?;
            timer.done();
            self.state = HashJoinStreamState::FetchProbeBatch;
            return Ok(StatefulStreamResult::Ready(Some(result)));
        }
        // Collect candidate (build, probe) index pairs, bounded by
        // `batch_size` and resumable via `state.offset`.
        let (left_indices, right_indices, next_offset) = match build_side.left_data.map()
        {
            Map::HashMap(map) => lookup_join_hashmap(
                map.as_ref(),
                build_side.left_data.values(),
                &state.values,
                self.null_equality,
                &self.hashes_buffer,
                self.batch_size,
                state.offset,
                &mut self.probe_indices_buffer,
                &mut self.build_indices_buffer,
            )?,
            Map::ArrayMap(array_map) => {
                let next_offset = array_map.get_matched_indices_with_limit_offset(
                    &state.values,
                    self.batch_size,
                    state.offset,
                    &mut self.probe_indices_buffer,
                    &mut self.build_indices_buffer,
                )?;
                (
                    UInt64Array::from(self.build_indices_buffer.clone()),
                    UInt32Array::from(self.probe_indices_buffer.clone()),
                    next_offset,
                )
            }
        };
        // Probe indices are produced in sorted order, so distinct counting
        // is a single linear scan.
        let distinct_right_indices_count = count_distinct_sorted_indices(&right_indices);
        self.join_metrics
            .probe_hit_rate
            .add_part(distinct_right_indices_count);
        self.join_metrics.avg_fanout.add_part(left_indices.len());
        self.join_metrics
            .avg_fanout
            .add_total(distinct_right_indices_count);
        // Apply the residual (non-equijoin) filter to the candidate pairs.
        let (left_indices, right_indices) = if let Some(filter) = &self.filter {
            apply_join_filter_to_indices(
                build_side.left_data.batch(),
                &state.batch,
                left_indices,
                right_indices,
                filter,
                JoinSide::Left,
                None,
                self.join_type,
            )?
        } else {
            (left_indices, right_indices)
        };
        // Mark matched build rows so unmatched ones can be emitted later.
        if need_produce_result_in_final(self.join_type) {
            let mut bitmap = build_side.left_data.visited_indices_bitmap().lock();
            left_indices.iter().flatten().for_each(|x| {
                bitmap.set_bit(x as usize, true);
            });
        }
        let last_joined_right_idx = match right_indices.len() {
            0 => None,
            n => Some(right_indices.value(n - 1) as usize),
        };
        // Probe rows inside this range that are absent from `right_indices`
        // are unmatched; `adjust_indices_by_join_type` accounts for them.
        // When the batch is fully consumed the range runs to its end;
        // otherwise only up to the last joined probe row.
        let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1);
        let index_alignment_range_end = if next_offset.is_none() {
            state.batch.num_rows()
        } else {
            last_joined_right_idx.map_or(0, |v| v + 1)
        };
        let (left_indices, right_indices) = adjust_indices_by_join_type(
            left_indices,
            right_indices,
            index_alignment_range_start..index_alignment_range_end,
            self.join_type,
            self.right_side_ordered,
        )?;
        // RightMark swaps which input plays the "build" role during output
        // assembly.
        let (build_batch, probe_batch, join_side) =
            if self.join_type == JoinType::RightMark {
                (&state.batch, build_side.left_data.batch(), JoinSide::Right)
            } else {
                (build_side.left_data.batch(), &state.batch, JoinSide::Left)
            };
        let batch = build_batch_from_indices(
            &self.schema,
            build_batch,
            probe_batch,
            &left_indices,
            &right_indices,
            &self.column_indices,
            join_side,
            self.join_type,
        )?;
        let push_status = self.output_buffer.push_batch(batch)?;
        timer.done();
        if push_status == PushBatchStatus::LimitReached {
            // Output row limit hit: flush buffered rows and finish early.
            self.output_buffer.finish()?;
            self.state = HashJoinStreamState::Completed;
            return Ok(StatefulStreamResult::Continue);
        }
        if next_offset.is_none() {
            // Batch fully processed; move on to the next probe batch.
            self.state = HashJoinStreamState::FetchProbeBatch;
        } else {
            // More matches remain for this batch; record the resume point.
            state.advance(
                next_offset
                    .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?,
                last_joined_right_idx,
            )
        };
        Ok(StatefulStreamResult::Continue)
    }
fn process_unmatched_build_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let timer = self.join_metrics.join_time.timer();
if !need_produce_result_in_final(self.join_type) {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
}
let build_side = self.build_side.try_as_ready()?;
if self.null_aware
&& build_side
.left_data
.probe_side_has_null
.load(Ordering::Relaxed)
{
timer.done();
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
}
if !build_side.left_data.report_probe_completed() {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
}
let (mut left_side, mut right_side) = get_final_indices_from_shared_bitmap(
build_side.left_data.visited_indices_bitmap(),
self.join_type,
true,
);
if self.null_aware
&& self.join_type == JoinType::LeftAnti
&& build_side
.left_data
.probe_side_non_empty
.load(Ordering::Relaxed)
{
let build_key_column = &build_side.left_data.values()[0];
let filtered_indices: Vec<u64> = left_side
.iter()
.filter_map(|idx| {
let idx_usize = idx.unwrap() as usize;
if build_key_column.is_null(idx_usize) {
None } else {
Some(idx.unwrap())
}
})
.collect();
left_side = UInt64Array::from(filtered_indices);
let mut builder = arrow::array::UInt32Builder::with_capacity(left_side.len());
builder.append_nulls(left_side.len());
right_side = builder.finish();
}
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(left_side.len());
timer.done();
self.state = HashJoinStreamState::Completed;
if !left_side.is_empty() {
let empty_right_batch = RecordBatch::new_empty(self.right.schema());
let batch = build_batch_from_indices(
&self.schema,
build_side.left_data.batch(),
&empty_right_batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
self.join_type,
)?;
let push_status = self.output_buffer.push_batch(batch)?;
if push_status == PushBatchStatus::LimitReached {
self.output_buffer.finish()?;
}
}
Ok(StatefulStreamResult::Continue)
}
}
impl Stream for HashJoinStream {
    type Item = Result<RecordBatch>;

    /// Delegates to `poll_next_impl`; calling through `Pin<&mut Self>`
    /// works because the stream type is `Unpin`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}