use std::any::Any;
use std::fmt::Formatter;
use std::ops::{BitOr, ControlFlow};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Poll;
use super::utils::{
asymmetric_join_output_partitioning, need_produce_result_in_final,
reorder_output_after_swap, swap_join_projection,
};
use crate::common::can_project;
use crate::execution_plan::{EmissionType, boundedness_from_children};
use crate::joins::SharedBitmapBuilder;
use crate::joins::utils::{
BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut,
build_join_schema, check_join_is_valid, estimate_join_statistics,
need_produce_right_in_final,
};
use crate::metrics::{
Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, MetricsSet, RatioMetrics,
};
use crate::projection::{
EmbeddedProjection, JoinData, ProjectionExec, try_embed_projection,
try_pushdown_through_join,
};
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
PlanProperties, RecordBatchStream, SendableRecordBatchStream,
check_if_same_properties,
};
use arrow::array::{
Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, UInt32Array,
UInt64Array, new_null_array,
};
use arrow::buffer::BooleanBuffer;
use arrow::compute::{
BatchCoalescer, concat_batches, filter, filter_record_batch, not, take,
};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::DataType;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::{
JoinSide, Result, ScalarValue, Statistics, arrow_err, assert_eq_or_internal_err,
internal_datafusion_err, internal_err, project_schema, unwrap_or_internal_err,
};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_expr::JoinType;
use datafusion_physical_expr::equivalence::{
ProjectionMapping, join_equivalence_properties,
};
use datafusion_physical_expr::projection::{ProjectionRef, combine_projections};
use futures::{Stream, StreamExt, TryStreamExt};
use log::debug;
use parking_lot::Mutex;
/// Physical operator that joins two inputs without equi-join keys by
/// comparing every build-side (left) row against every probe-side (right)
/// row, optionally applying a [`JoinFilter`] to each candidate pair.
#[expect(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
pub struct NestedLoopJoinExec {
    /// Build side; fully buffered in memory before probing starts.
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Probe side; streamed batch-by-batch per output partition.
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Optional non-equi predicate evaluated on each candidate row pair.
    pub(crate) filter: Option<JoinFilter>,
    /// Which join semantics (inner/left/semi/anti/mark/...) to apply.
    pub(crate) join_type: JoinType,
    /// Schema of the unprojected join output (left fields ++ right fields).
    join_schema: SchemaRef,
    /// Lazily-built, shared build-side data; computed once and shared across
    /// all probe partitions.
    build_side_data: OnceAsync<JoinLeftData>,
    /// For each output column, which input side and column it comes from.
    column_indices: Vec<ColumnIndex>,
    /// Optional projection (indices into `join_schema`) applied to the output.
    projection: Option<ProjectionRef>,
    /// Execution metrics, one set per partition.
    metrics: ExecutionPlanMetricsSet,
    /// Precomputed plan properties (partitioning, equivalences, emission).
    cache: Arc<PlanProperties>,
}
/// Builder for [`NestedLoopJoinExec`]; collects the optional filter and
/// projection before validation happens in [`NestedLoopJoinExecBuilder::build`].
pub struct NestedLoopJoinExecBuilder {
    /// Build-side input.
    left: Arc<dyn ExecutionPlan>,
    /// Probe-side input.
    right: Arc<dyn ExecutionPlan>,
    /// Join semantics to apply.
    join_type: JoinType,
    /// Optional non-equi join predicate.
    filter: Option<JoinFilter>,
    /// Optional output projection over the join schema.
    projection: Option<ProjectionRef>,
}
impl NestedLoopJoinExecBuilder {
    /// Starts building a nested-loop join between `left` and `right`.
    pub fn new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        join_type: JoinType,
    ) -> Self {
        Self {
            left,
            right,
            join_type,
            filter: None,
            projection: None,
        }
    }

    /// Sets the output projection from plain column indices.
    pub fn with_projection(self, projection: Option<Vec<usize>>) -> Self {
        let projection_ref = projection.map(Into::into);
        self.with_projection_ref(projection_ref)
    }

    /// Sets the output projection from an already-shared projection.
    pub fn with_projection_ref(mut self, projection: Option<ProjectionRef>) -> Self {
        self.projection = projection;
        self
    }

    /// Sets the optional join filter evaluated on each candidate row pair.
    pub fn with_filter(mut self, filter: Option<JoinFilter>) -> Self {
        self.filter = filter;
        self
    }

    /// Validates the inputs and assembles the [`NestedLoopJoinExec`] node.
    pub fn build(self) -> Result<NestedLoopJoinExec> {
        let left_schema = self.left.schema();
        let right_schema = self.right.schema();
        // A nested-loop join carries no equi-join keys, hence the empty `on` list.
        check_join_is_valid(&left_schema, &right_schema, &[])?;
        let (schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, &self.join_type);
        let join_schema = Arc::new(schema);
        let cache = NestedLoopJoinExec::compute_properties(
            &self.left,
            &self.right,
            &join_schema,
            self.join_type,
            self.projection.as_deref(),
        )?;
        Ok(NestedLoopJoinExec {
            left: self.left,
            right: self.right,
            filter: self.filter,
            join_type: self.join_type,
            join_schema,
            build_side_data: Default::default(),
            column_indices,
            projection: self.projection,
            metrics: Default::default(),
            cache: Arc::new(cache),
        })
    }
}
impl From<&NestedLoopJoinExec> for NestedLoopJoinExecBuilder {
fn from(exec: &NestedLoopJoinExec) -> Self {
Self {
left: Arc::clone(exec.left()),
right: Arc::clone(exec.right()),
join_type: exec.join_type,
filter: exec.filter.clone(),
projection: exec.projection.clone(),
}
}
}
impl NestedLoopJoinExec {
    /// Creates a nested-loop join; thin wrapper around
    /// [`NestedLoopJoinExecBuilder`].
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        projection: Option<Vec<usize>>,
    ) -> Result<Self> {
        NestedLoopJoinExecBuilder::new(left, right, *join_type)
            .with_projection(projection)
            .with_filter(filter)
            .build()
    }

    /// Build (buffered) side of the join.
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// Probe (streamed) side of the join.
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// Optional non-equi filter applied to each candidate row pair.
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }

    /// The join type.
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// Optional output projection (indices into the unprojected join schema).
    pub fn projection(&self) -> &Option<ProjectionRef> {
        &self.projection
    }

    /// Computes the cached [`PlanProperties`]: equivalences, output
    /// partitioning, emission type and boundedness for the given children,
    /// join type and optional projection.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: &SchemaRef,
        join_type: JoinType,
        projection: Option<&[usize]>,
    ) -> Result<PlanProperties> {
        let mut eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            Arc::clone(schema),
            &Self::maintains_input_order(join_type),
            None,
            // No equi-join `on` columns for nested-loop joins.
            &[],
        )?;
        let mut output_partitioning =
            asymmetric_join_output_partitioning(left, right, &join_type)?;
        // Emission type: an unbounded build side means results can only be
        // produced at the end; otherwise incremental right-side emission is
        // kept for join types that do not need to emit left-unmatched rows,
        // while left/full variants must also emit a final pass (Both).
        let emission_type = if left.boundedness().is_unbounded() {
            EmissionType::Final
        } else if right.pipeline_behavior() == EmissionType::Incremental {
            match join_type {
                JoinType::Inner
                | JoinType::LeftSemi
                | JoinType::RightSemi
                | JoinType::Right
                | JoinType::RightAnti
                | JoinType::RightMark => EmissionType::Incremental,
                JoinType::Left
                | JoinType::LeftAnti
                | JoinType::LeftMark
                | JoinType::Full => EmissionType::Both,
            }
        } else {
            right.pipeline_behavior()
        };
        // Project both the partitioning and the equivalences when an output
        // projection is present, so downstream optimizers see projected exprs.
        if let Some(projection) = projection {
            let projection_mapping = ProjectionMapping::from_indices(projection, schema)?;
            let out_schema = project_schema(schema, Some(&projection))?;
            output_partitioning =
                output_partitioning.project(&projection_mapping, &eq_properties);
            eq_properties = eq_properties.project(&projection_mapping, out_schema);
        }
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            boundedness_from_children([left, right]),
        ))
    }

    /// Nested-loop join never preserves the input ordering of either child.
    fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
        vec![false, false]
    }

    /// Whether an output projection is embedded in this node.
    pub fn contains_projection(&self) -> bool {
        self.projection.is_some()
    }

    /// Returns a copy of this node with `projection` composed on top of any
    /// existing embedded projection.
    pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
        let projection = projection.map(Into::into);
        can_project(&self.schema(), projection.as_deref())?;
        // Compose the new projection with the already-embedded one.
        let projection =
            combine_projections(projection.as_ref(), self.projection.as_ref())?;
        NestedLoopJoinExecBuilder::from(self)
            .with_projection_ref(projection)
            .build()
    }

    /// Returns a new plan with the build and probe sides exchanged and the
    /// join type/filter/projection swapped accordingly.
    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
        let left = self.left();
        let right = self.right();
        let new_join = NestedLoopJoinExec::try_new(
            Arc::clone(right),
            Arc::clone(left),
            self.filter().map(JoinFilter::swap),
            &self.join_type().swap(),
            swap_join_projection(
                left.schema().fields().len(),
                right.schema().fields().len(),
                self.projection.as_deref(),
                self.join_type(),
            ),
        )?;
        // Semi/anti/mark joins output columns from only one side, and an
        // embedded projection already fixes the column order, so no reorder
        // is needed in those cases; otherwise restore the original column
        // order on top of the swapped join.
        let plan: Arc<dyn ExecutionPlan> = if matches!(
            self.join_type(),
            JoinType::LeftSemi
                | JoinType::RightSemi
                | JoinType::LeftAnti
                | JoinType::RightAnti
                | JoinType::LeftMark
                | JoinType::RightMark
        ) || self.projection.is_some()
        {
            Arc::new(new_join)
        } else {
            reorder_output_after_swap(
                Arc::new(new_join),
                &self.left().schema(),
                &self.right().schema(),
            )?
        };
        Ok(plan)
    }

    /// Rebuilds this node with new children while reusing the cached plan
    /// properties (callers must ensure the children's properties match).
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        let left = children.swap_remove(0);
        let right = children.swap_remove(0);
        Self {
            left,
            right,
            // Fresh metrics and build-side state; the rest is shared/cloned.
            metrics: ExecutionPlanMetricsSet::new(),
            build_side_data: Default::default(),
            cache: Arc::clone(&self.cache),
            filter: self.filter.clone(),
            join_type: self.join_type,
            join_schema: Arc::clone(&self.join_schema),
            column_indices: self.column_indices.clone(),
            projection: self.projection.clone(),
        }
    }
}
impl DisplayAs for NestedLoopJoinExec {
    /// Renders this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                // Filter is only shown when present.
                let display_filter = self.filter.as_ref().map_or_else(
                    || "".to_string(),
                    |f| format!(", filter={}", f.expression()),
                );
                // Projection is rendered as `name@index` pairs, where names
                // come from the unprojected join schema.
                let display_projections = if self.contains_projection() {
                    format!(
                        ", projection=[{}]",
                        self.projection
                            .as_ref()
                            .unwrap()
                            .iter()
                            .map(|index| format!(
                                "{}@{}",
                                self.join_schema.fields().get(*index).unwrap().name(),
                                index
                            ))
                            .collect::<Vec<_>>()
                            .join(", ")
                    )
                } else {
                    "".to_string()
                };
                write!(
                    f,
                    "NestedLoopJoinExec: join_type={:?}{}{}",
                    self.join_type, display_filter, display_projections
                )
            }
            DisplayFormatType::TreeRender => {
                // Tree rendering elides the default (Inner) join type.
                if *self.join_type() != JoinType::Inner {
                    writeln!(f, "join_type={:?}", self.join_type)
                } else {
                    Ok(())
                }
            }
        }
    }
}
impl ExecutionPlan for NestedLoopJoinExec {
    fn name(&self) -> &'static str {
        "NestedLoopJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Properties are precomputed by the builder and cached.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    /// The build (left) child must be a single partition because every probe
    /// partition joins against the complete left side.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![
            Distribution::SinglePartition,
            Distribution::UnspecifiedDistribution,
        ]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order(self.join_type)
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // NOTE(review): presumably short-circuits via
        // `with_new_children_and_same_properties` when the new children have
        // identical properties — confirm macro semantics.
        check_if_same_properties!(self, children);
        Ok(Arc::new(
            NestedLoopJoinExecBuilder::new(
                Arc::clone(&children[0]),
                Arc::clone(&children[1]),
                self.join_type,
            )
            .with_filter(self.filter.clone())
            .with_projection_ref(self.projection.clone())
            .build()?,
        ))
    }

    /// Starts execution for `partition`: kicks off (or joins) the shared
    /// build-side collection and returns the streaming join for the matching
    /// probe partition.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        assert_eq_or_internal_err!(
            self.left.output_partitioning().partition_count(),
            1,
            "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
             consider using CoalescePartitionsExec or the EnforceDistribution rule"
        );
        let metrics = NestedLoopJoinMetrics::new(&self.metrics, partition);
        // Reservation that accounts for the buffered build side.
        let load_reservation =
            MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
                .register(context.memory_pool());
        // `try_once` ensures the left side is collected exactly once and
        // shared by all probe partitions.
        let build_side_data = self.build_side_data.try_once(|| {
            let stream = self.left.execute(0, Arc::clone(&context))?;
            Ok(collect_left_input(
                stream,
                metrics.join_metrics.clone(),
                load_reservation,
                need_produce_result_in_final(self.join_type),
                self.right().output_partitioning().partition_count(),
            ))
        })?;
        let batch_size = context.session_config().batch_size();
        let probe_side_data = self.right.execute(partition, context)?;
        // The stream materializes output directly in projected order, so map
        // the column indices through the embedded projection up front.
        let column_indices_after_projection = match self.projection.as_ref() {
            Some(projection) => projection
                .iter()
                .map(|i| self.column_indices[*i].clone())
                .collect(),
            None => self.column_indices.clone(),
        };
        Ok(Box::pin(NestedLoopJoinStream::new(
            self.schema(),
            self.filter.clone(),
            self.join_type,
            probe_side_data,
            build_side_data,
            column_indices_after_projection,
            metrics,
            batch_size,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Estimates output statistics; `join_columns` is empty because a
    /// nested-loop join has no equi-join keys.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        let join_columns = Vec::new();
        // The left side is a single partition, so always use its total stats.
        let left_stats = self.left.partition_statistics(None)?;
        let right_stats = match partition {
            Some(partition) => self.right.partition_statistics(Some(partition))?,
            None => self.right.partition_statistics(None)?,
        };
        let stats = estimate_join_statistics(
            left_stats,
            right_stats,
            &join_columns,
            &self.join_type,
            &self.join_schema,
        )?;
        // Statistics are reported for the projected output.
        Ok(stats.project(self.projection.as_ref()))
    }

    /// Tries to push a projection sitting on top of this join down into its
    /// children; falls back to embedding the projection into the join itself.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // An already-embedded projection blocks further pushdown.
        if self.contains_projection() {
            return Ok(None);
        }
        let schema = self.schema();
        if let Some(JoinData {
            projected_left_child,
            projected_right_child,
            join_filter,
            ..
        }) = try_pushdown_through_join(
            projection,
            self.left(),
            self.right(),
            &[],
            &schema,
            self.filter(),
        )? {
            Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
                Arc::new(projected_left_child),
                Arc::new(projected_right_child),
                join_filter,
                self.join_type(),
                None,
            )?)))
        } else {
            try_embed_projection(projection, self)
        }
    }
}
impl EmbeddedProjection for NestedLoopJoinExec {
fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
self.with_projection(projection)
}
}
/// Fully-buffered build side shared by all probe partitions.
pub(crate) struct JoinLeftData {
    /// All build-side input concatenated into a single batch.
    batch: RecordBatch,
    /// Shared bitmap with one bit per build row, set when the row matches.
    bitmap: SharedBitmapBuilder,
    /// Countdown of probe partitions still running; the last one to finish
    /// emits the left-unmatched rows.
    probe_threads_counter: AtomicUsize,
    /// Held only so the build-side memory stays accounted for this data's
    /// lifetime; never read.
    #[expect(dead_code)]
    reservation: MemoryReservation,
}
impl JoinLeftData {
pub(crate) fn new(
batch: RecordBatch,
bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
) -> Self {
Self {
batch,
bitmap,
probe_threads_counter,
reservation,
}
}
pub(crate) fn batch(&self) -> &RecordBatch {
&self.batch
}
pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
&self.bitmap
}
pub(crate) fn report_probe_completed(&self) -> bool {
self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
}
}
/// Drains the build-side `stream`, concatenates everything into one
/// `RecordBatch`, and — when the join type requires emitting left-unmatched
/// rows — allocates a zeroed one-bit-per-row visited bitmap.
///
/// All buffered memory (batches and bitmap) is charged to `reservation`;
/// `probe_threads_count` seeds the countdown that decides which probe
/// partition performs the final left-unmatched pass.
async fn collect_left_input(
    stream: SendableRecordBatchStream,
    join_metrics: BuildProbeJoinMetrics,
    reservation: MemoryReservation,
    with_visited_left_side: bool,
    probe_threads_count: usize,
) -> Result<JoinLeftData> {
    let schema = stream.schema();
    // Fold the stream, growing the reservation and recording build metrics
    // for every incoming batch before buffering it.
    let (batches, metrics, reservation) = stream
        .try_fold(
            (Vec::new(), join_metrics, reservation),
            |(mut batches, metrics, reservation), batch| async {
                let batch_size = batch.get_array_memory_size();
                reservation.try_grow(batch_size)?;
                metrics.build_mem_used.add(batch_size);
                metrics.build_input_batches.add(1);
                metrics.build_input_rows.add(batch.num_rows());
                batches.push(batch);
                Ok((batches, metrics, reservation))
            },
        )
        .await?;
    let merged_batch = concat_batches(&schema, &batches)?;
    let visited_left_side = if with_visited_left_side {
        let n_rows = merged_batch.num_rows();
        // One bit per build row, rounded up to whole bytes for accounting.
        let buffer_size = n_rows.div_ceil(8);
        reservation.try_grow(buffer_size)?;
        metrics.build_mem_used.add(buffer_size);
        let mut buffer = BooleanBufferBuilder::new(n_rows);
        // All rows start as "unmatched".
        buffer.append_n(n_rows, false);
        buffer
    } else {
        // Join types that never emit left-unmatched rows get an empty bitmap.
        BooleanBufferBuilder::new(0)
    };
    Ok(JoinLeftData::new(
        merged_batch,
        Mutex::new(visited_left_side),
        AtomicUsize::new(probe_threads_count),
        reservation,
    ))
}
/// States of the nested-loop-join stream's state machine (see the `Stream`
/// impl and the `handle_*` methods for the transitions).
#[derive(Debug, Clone, Copy)]
enum NLJState {
    /// Awaiting the shared build-side future; then -> `FetchingRight`.
    BufferingLeft,
    /// Polling the next probe batch; empty batches are skipped. On a batch
    /// -> `ProbeRight`; on probe exhaustion -> `EmitLeftUnmatched`.
    FetchingRight,
    /// Joining buffered left rows against the current right batch. When the
    /// left cursor is exhausted -> `EmitRightUnmatched` (right-producing
    /// join types) or back to `FetchingRight`.
    ProbeRight,
    /// Emitting unmatched rows of the current right batch; -> `FetchingRight`.
    EmitRightUnmatched,
    /// Emitting unmatched build-side rows in `batch_size` slices (only in
    /// the last-finishing partition); -> `Done`.
    EmitLeftUnmatched,
    /// Flushing any buffered output; may emit one empty batch if the join
    /// produced no rows at all.
    Done,
}
/// Streaming state for one probe partition of a nested-loop join.
pub(crate) struct NestedLoopJoinStream {
    /// Schema of emitted batches (after any embedded projection).
    pub(crate) output_schema: Arc<Schema>,
    /// Optional non-equi predicate applied to candidate row pairs.
    pub(crate) join_filter: Option<JoinFilter>,
    /// Join semantics being produced.
    pub(crate) join_type: JoinType,
    /// Probe-side input stream.
    pub(crate) right_data: SendableRecordBatchStream,
    /// Shared future resolving to the buffered build side.
    pub(crate) left_data: OnceFut<JoinLeftData>,
    /// Source side/column for each output column (already projected).
    pub(crate) column_indices: Vec<ColumnIndex>,
    /// Per-partition metrics.
    pub(crate) metrics: NestedLoopJoinMetrics,
    // Target output batch size; also sizes left-unmatched emission slices.
    batch_size: usize,
    // True for join types that must emit unmatched right-side rows.
    should_track_unmatched_right: bool,
    // Current state-machine state.
    state: NLJState,
    // Accumulates output rows until a full `batch_size` batch is ready.
    output_buffer: Box<BatchCoalescer>,
    // Ensures the "emit one empty batch when no rows were output" path in
    // `handle_done` runs at most once.
    handled_empty_output: bool,
    // Build side, once `left_data` resolves.
    buffered_left_data: Option<Arc<JoinLeftData>>,
    // Cursor into the build side while probing the current right batch.
    left_probe_idx: usize,
    // Cursor into the build side during `EmitLeftUnmatched`.
    left_emit_idx: usize,
    // Set once the build side has been fully materialized.
    left_exhausted: bool,
    // Currently unused bookkeeping flag.
    #[expect(dead_code)]
    left_buffered_in_one_pass: bool,
    // Right batch currently being probed, if any.
    current_right_batch: Option<RecordBatch>,
    // Per-right-row matched bitmap for the current right batch (only kept
    // when `should_track_unmatched_right`).
    current_right_batch_matched: Option<BooleanArray>,
}
/// Metrics recorded by the nested-loop-join stream.
pub(crate) struct NestedLoopJoinMetrics {
    /// Standard build/probe join metrics (build time, input rows, ...).
    pub(crate) join_metrics: BuildProbeJoinMetrics,
    /// Output rows over the cartesian-product rows examined.
    pub(crate) selectivity: RatioMetrics,
}
impl NestedLoopJoinMetrics {
    /// Registers all nested-loop-join metrics for `partition`.
    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        let join_metrics = BuildProbeJoinMetrics::new(partition, metrics);
        let selectivity = MetricBuilder::new(metrics)
            .with_type(MetricType::SUMMARY)
            .ratio_metrics("selectivity", partition);
        Self {
            join_metrics,
            selectivity,
        }
    }
}
impl Stream for NestedLoopJoinStream {
    type Item = Result<RecordBatch>;

    /// Drives the [`NLJState`] machine until a batch is ready, the stream
    /// pends, or the join finishes. Each handler returns
    /// `ControlFlow::Continue` to loop into the next state or
    /// `ControlFlow::Break` with a `Poll` to hand back to the caller.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.state {
                NLJState::BufferingLeft => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    // Build-side materialization counts as build time; the
                    // timer records on drop.
                    let build_metric = self.metrics.join_metrics.build_time.clone();
                    let _build_timer = build_metric.timer();
                    match self.handle_buffering_left(cx) {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => return poll,
                    }
                }
                NLJState::FetchingRight => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    let join_metric = self.metrics.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();
                    match self.handle_fetching_right(cx) {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => return poll,
                    }
                }
                NLJState::ProbeRight => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    let join_metric = self.metrics.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();
                    match self.handle_probe_right() {
                        ControlFlow::Continue(()) => continue,
                        // States that can emit rows route their result
                        // through `record_poll` for output-row accounting.
                        ControlFlow::Break(poll) => {
                            return self.metrics.join_metrics.baseline.record_poll(poll);
                        }
                    }
                }
                NLJState::EmitRightUnmatched => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    let join_metric = self.metrics.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();
                    match self.handle_emit_right_unmatched() {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => {
                            return self.metrics.join_metrics.baseline.record_poll(poll);
                        }
                    }
                }
                NLJState::EmitLeftUnmatched => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    let join_metric = self.metrics.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();
                    match self.handle_emit_left_unmatched() {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => {
                            return self.metrics.join_metrics.baseline.record_poll(poll);
                        }
                    }
                }
                NLJState::Done => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    let join_metric = self.metrics.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();
                    let poll = self.handle_done();
                    return self.metrics.join_metrics.baseline.record_poll(poll);
                }
            }
        }
    }
}
impl RecordBatchStream for NestedLoopJoinStream {
    /// Schema of the batches this stream yields.
    fn schema(&self) -> SchemaRef {
        let schema = &self.output_schema;
        Arc::clone(schema)
    }
}
impl NestedLoopJoinStream {
/// Creates a join stream for one probe partition; starts in
/// [`NLJState::BufferingLeft`] to await the shared build side.
#[expect(clippy::too_many_arguments)]
pub(crate) fn new(
    schema: Arc<Schema>,
    filter: Option<JoinFilter>,
    join_type: JoinType,
    right_data: SendableRecordBatchStream,
    left_data: OnceFut<JoinLeftData>,
    column_indices: Vec<ColumnIndex>,
    metrics: NestedLoopJoinMetrics,
    batch_size: usize,
) -> Self {
    Self {
        output_schema: Arc::clone(&schema),
        join_filter: filter,
        join_type,
        right_data,
        column_indices,
        left_data,
        metrics,
        buffered_left_data: None,
        // Coalesces output fragments into `batch_size`-row batches.
        output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
        batch_size,
        current_right_batch: None,
        current_right_batch_matched: None,
        state: NLJState::BufferingLeft,
        left_probe_idx: 0,
        left_emit_idx: 0,
        left_exhausted: false,
        left_buffered_in_one_pass: true,
        handled_empty_output: false,
        // Only join types emitting unmatched right rows need a right bitmap.
        should_track_unmatched_right: need_produce_right_in_final(join_type),
    }
}
/// `BufferingLeft`: polls the shared build-side future; once resolved the
/// left side is complete, so move on to fetching the first right batch.
fn handle_buffering_left(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
    match self.left_data.get_shared(cx) {
        Poll::Ready(Ok(left_data)) => {
            self.buffered_left_data = Some(left_data);
            // The future resolves with the fully-collected build side.
            self.left_exhausted = true;
            self.state = NLJState::FetchingRight;
            ControlFlow::Continue(())
        }
        Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
        Poll::Pending => ControlFlow::Break(Poll::Pending),
    }
}
/// `FetchingRight`: pulls the next probe batch. Empty batches are counted
/// but skipped; stream exhaustion transitions to `EmitLeftUnmatched`.
fn handle_fetching_right(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
    match self.right_data.poll_next_unpin(cx) {
        Poll::Ready(result) => match result {
            Some(Ok(right_batch)) => {
                let right_batch_size = right_batch.num_rows();
                self.metrics.join_metrics.input_rows.add(right_batch_size);
                self.metrics.join_metrics.input_batches.add(1);
                // Skip empty batches entirely (ProbeRight relies on this).
                if right_batch_size == 0 {
                    return ControlFlow::Continue(());
                }
                self.current_right_batch = Some(right_batch);
                // Start an all-zero matched bitmap for this right batch.
                if self.should_track_unmatched_right {
                    let zeroed_buf = BooleanBuffer::new_unset(right_batch_size);
                    self.current_right_batch_matched =
                        Some(BooleanArray::new(zeroed_buf, None));
                }
                // Restart the build-side cursor for the new right batch.
                self.left_probe_idx = 0;
                self.state = NLJState::ProbeRight;
                ControlFlow::Continue(())
            }
            Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
            None => {
                // Probe side exhausted: emit left-unmatched rows (if any).
                self.state = NLJState::EmitLeftUnmatched;
                ControlFlow::Continue(())
            }
        },
        Poll::Pending => ControlFlow::Break(Poll::Pending),
    }
}
/// `ProbeRight`: flushes a ready output batch if one is buffered, otherwise
/// advances the join over the current right batch; when the build side is
/// exhausted for this right batch, records selectivity and transitions.
fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
    if let Some(poll) = self.maybe_flush_ready_batch() {
        return ControlFlow::Break(poll);
    }
    match self.process_probe_batch() {
        // Made progress; keep probing.
        Ok(true) => ControlFlow::Continue(()),
        // Build side exhausted for the current right batch.
        Ok(false) => {
            self.left_probe_idx = 0;
            // Denominator of selectivity: cartesian rows examined.
            if let (Ok(left_data), Some(right_batch)) =
                (self.get_left_data(), self.current_right_batch.as_ref())
            {
                let left_rows = left_data.batch().num_rows();
                let right_rows = right_batch.num_rows();
                self.metrics.selectivity.add_total(left_rows * right_rows);
            }
            if self.should_track_unmatched_right {
                debug_assert!(
                    self.current_right_batch_matched.is_some(),
                    "If it's required to track matched rows in the right input, the right bitmap must be present"
                );
                self.state = NLJState::EmitRightUnmatched;
            } else {
                // Done with this right batch; drop it and fetch the next.
                self.current_right_batch = None;
                self.state = NLJState::FetchingRight;
            }
            ControlFlow::Continue(())
        }
        Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
    }
}
/// `EmitRightUnmatched`: flushes any ready output batch, then emits the
/// unmatched rows of the current right batch (if any) and returns to
/// `FetchingRight`.
///
/// The `Ok(Some)`/`Ok(None)` arms previously duplicated the push-then-
/// transition logic; they are folded into one arm with identical behavior.
fn handle_emit_right_unmatched(
    &mut self,
) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
    if let Some(poll) = self.maybe_flush_ready_batch() {
        return ControlFlow::Break(poll);
    }
    debug_assert!(
        self.current_right_batch_matched.is_some()
            && self.current_right_batch.is_some(),
        "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
    );
    match self.process_right_unmatched() {
        Ok(maybe_batch) => {
            // `None` means every right row matched; nothing to buffer.
            if let Some(batch) = maybe_batch {
                if let Err(e) = self.output_buffer.push_batch(batch) {
                    return ControlFlow::Break(Poll::Ready(Some(arrow_err!(e))));
                }
            }
            // `process_right_unmatched` consumed the right batch.
            debug_assert!(self.current_right_batch.is_none());
            self.state = NLJState::FetchingRight;
            ControlFlow::Continue(())
        }
        Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
    }
}
/// `EmitLeftUnmatched`: flushes ready output, then emits left-unmatched
/// rows slice by slice; when finished, seals the output buffer and moves
/// to `Done`.
fn handle_emit_left_unmatched(
    &mut self,
) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
    if let Some(poll) = self.maybe_flush_ready_batch() {
        return ControlFlow::Break(poll);
    }
    match self.process_left_unmatched() {
        // More left slices remain (or this call emitted one).
        Ok(true) => ControlFlow::Continue(()),
        // Nothing (more) to emit: flush the partially-filled final batch.
        Ok(false) => match self.output_buffer.finish_buffered_batch() {
            Ok(()) => {
                self.state = NLJState::Done;
                ControlFlow::Continue(())
            }
            Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
        },
        Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
    }
}
/// `Done`: drains remaining buffered batches; if the join produced no rows
/// at all, emits a single empty batch (once) so downstream operators still
/// observe the output schema, then ends the stream.
fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
    if let Some(poll) = self.maybe_flush_ready_batch() {
        return poll;
    }
    if !self.handled_empty_output {
        // Previously compared against a freshly built `Count::new()`;
        // reading the counter value directly is equivalent and simpler.
        if self.metrics.join_metrics.baseline.output_rows().value() == 0 {
            let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
            self.handled_empty_output = true;
            return Poll::Ready(Some(Ok(empty_batch)));
        }
    }
    Poll::Ready(None)
}
/// Advances the join of the buffered left side against the current right
/// batch. Returns `Ok(true)` if progress was made and `Ok(false)` once the
/// left cursor has passed the last build row.
fn process_probe_batch(&mut self) -> Result<bool> {
    let left_data = Arc::clone(self.get_left_data()?);
    // RecordBatch clones are shallow (Arc'd columns), so this is cheap.
    let right_batch = self
        .current_right_batch
        .as_ref()
        .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
        .clone();
    if self.left_probe_idx >= left_data.batch().num_rows() {
        return Ok(false);
    }
    debug_assert_ne!(
        right_batch.num_rows(),
        0,
        "When fetching the right batch, empty batches will be skipped"
    );
    // Heuristic: when the right batch is small relative to the target batch
    // size, join a whole range of left rows at once (vectorized cartesian
    // expansion); otherwise process one left row per call.
    let l_row_cnt_ratio = self.batch_size / right_batch.num_rows();
    if l_row_cnt_ratio > 10 {
        let l_row_count = std::cmp::min(
            l_row_cnt_ratio,
            left_data.batch().num_rows() - self.left_probe_idx,
        );
        debug_assert!(
            l_row_count != 0,
            "This function should only be entered when there are remaining left rows to process"
        );
        let joined_batch = self.process_left_range_join(
            &left_data,
            &right_batch,
            self.left_probe_idx,
            l_row_count,
        )?;
        if let Some(batch) = joined_batch {
            self.output_buffer.push_batch(batch)?;
        }
        self.left_probe_idx += l_row_count;
        return Ok(true);
    }
    // Slow path: a single left row against the whole right batch.
    let l_idx = self.left_probe_idx;
    let joined_batch =
        self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
    if let Some(batch) = joined_batch {
        self.output_buffer.push_batch(batch)?;
    }
    self.left_probe_idx += 1;
    Ok(true)
}
/// Joins `l_row_count` build-side rows starting at `l_start_index` against
/// the whole `right_batch` in one vectorized pass: expands the cartesian
/// product via index arrays, evaluates the optional filter once over the
/// expanded rows, updates the left/right matched bitmaps, and materializes
/// the filtered result for join types that output joined rows.
///
/// Returns `Ok(None)` for semi/anti/mark joins, which only update bitmaps.
///
/// Fix: the right-bitmap merge operand was mojibake (`¤t_right_bitmap`,
/// an HTML `&curren;` entity that swallowed `&c`); restored
/// `&current_right_bitmap`, matching the `bitor` call in
/// `update_matched_bitmap`.
fn process_left_range_join(
    &mut self,
    left_data: &JoinLeftData,
    right_batch: &RecordBatch,
    l_start_index: usize,
    l_row_count: usize,
) -> Result<Option<RecordBatch>> {
    let right_rows = right_batch.num_rows();
    let total_rows = l_row_count * right_rows;
    // Expanded row i pairs left row `l_start_index + i / right_rows` with
    // right row `i % right_rows`.
    let left_indices: UInt32Array =
        UInt32Array::from_iter_values((0..l_row_count).flat_map(|i| {
            std::iter::repeat_n((l_start_index + i) as u32, right_rows)
        }));
    let right_indices: UInt32Array = UInt32Array::from_iter_values(
        (0..l_row_count).flat_map(|_| 0..right_rows as u32),
    );
    debug_assert!(
        left_indices.len() == right_indices.len()
            && right_indices.len() == total_rows,
        "The length or cartesian product should be (left_size * right_size)",
    );
    // Evaluate the filter over the expanded rows; without a filter every
    // pair matches.
    let bitmap_combined = if let Some(filter) = &self.join_filter {
        let intermediate_batch = if filter.schema.fields().is_empty() {
            // Filter referencing no columns still needs a row count.
            create_record_batch_with_empty_schema(
                Arc::new((*filter.schema).clone()),
                total_rows,
            )?
        } else {
            // Gather only the columns the filter needs, expanded via the
            // cartesian index arrays.
            let mut filter_columns: Vec<Arc<dyn Array>> =
                Vec::with_capacity(filter.column_indices().len());
            for column_index in filter.column_indices() {
                let array = if column_index.side == JoinSide::Left {
                    let col = left_data.batch().column(column_index.index);
                    take(col.as_ref(), &left_indices, None)?
                } else {
                    let col = right_batch.column(column_index.index);
                    take(col.as_ref(), &right_indices, None)?
                };
                filter_columns.push(array);
            }
            RecordBatch::try_new(Arc::new((*filter.schema).clone()), filter_columns)?
        };
        let filter_result = filter
            .expression()
            .evaluate(&intermediate_batch)?
            .into_array(intermediate_batch.num_rows())?;
        let filter_arr = as_boolean_array(&filter_result)?;
        // Produces a null-free mask (nulls treated as non-matches).
        boolean_mask_from_filter(filter_arr)
    } else {
        BooleanArray::from(vec![true; total_rows])
    };
    let mut left_bitmap = if need_produce_result_in_final(self.join_type) {
        Some(left_data.bitmap().lock())
    } else {
        None
    };
    // Chunk-local right bitmap; merged into the per-batch bitmap below.
    let mut local_right_bitmap = if self.should_track_unmatched_right {
        let mut current_right_batch_bitmap = BooleanBufferBuilder::new(right_rows);
        current_right_batch_bitmap.append_n(right_rows, false);
        Some(current_right_batch_bitmap)
    } else {
        None
    };
    // Fold the match mask back into the per-side bitmaps.
    for (i, is_matched) in bitmap_combined.iter().enumerate() {
        let is_matched = is_matched.ok_or_else(|| {
            internal_datafusion_err!("Must be Some after the previous combining step")
        })?;
        let l_index = l_start_index + i / right_rows;
        let r_index = i % right_rows;
        if let Some(bitmap) = left_bitmap.as_mut()
            && is_matched
        {
            bitmap.set_bit(l_index, true);
        }
        if let Some(bitmap) = local_right_bitmap.as_mut()
            && is_matched
        {
            bitmap.set_bit(r_index, true);
        }
    }
    if self.should_track_unmatched_right {
        // OR the chunk-local bitmap into the bitmap accumulated for the
        // whole current right batch.
        let global_right_bitmap =
            std::mem::take(&mut self.current_right_batch_matched).ok_or_else(
                || internal_datafusion_err!("right batch's bitmap should be present"),
            )?;
        let (buf, nulls) = global_right_bitmap.into_parts();
        debug_assert!(nulls.is_none());
        let current_right_bitmap = local_right_bitmap
            .ok_or_else(|| {
                internal_datafusion_err!(
                    "Should be Some if the current join type requires right bitmap"
                )
            })?
            .finish();
        let updated_global_right_bitmap = buf.bitor(&current_right_bitmap);
        self.current_right_batch_matched =
            Some(BooleanArray::new(updated_global_right_bitmap, None));
    }
    // Semi/anti/mark joins emit their results from the bitmaps later.
    if matches!(
        self.join_type,
        JoinType::LeftAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::RightAnti
            | JoinType::RightMark
            | JoinType::RightSemi
    ) {
        return Ok(None);
    }
    // Zero-column output: only the surviving row count matters.
    if self.output_schema.fields().is_empty() {
        let row_count = bitmap_combined.true_count();
        return Ok(Some(create_record_batch_with_empty_schema(
            Arc::clone(&self.output_schema),
            row_count,
        )?));
    }
    // Materialize the full expansion, then keep only matching rows.
    let mut out_columns: Vec<Arc<dyn Array>> =
        Vec::with_capacity(self.output_schema.fields().len());
    for column_index in &self.column_indices {
        let array = if column_index.side == JoinSide::Left {
            let col = left_data.batch().column(column_index.index);
            take(col.as_ref(), &left_indices, None)?
        } else {
            let col = right_batch.column(column_index.index);
            take(col.as_ref(), &right_indices, None)?
        };
        out_columns.push(array);
    }
    let pre_filtered =
        RecordBatch::try_new(Arc::clone(&self.output_schema), out_columns)?;
    let filtered = filter_record_batch(&pre_filtered, &bitmap_combined)?;
    Ok(Some(filtered))
}
/// Joins one build-side row (`l_index`) against the whole `right_batch`:
/// evaluates the optional filter, updates the left/right matched bitmaps,
/// and for joined-output types builds the result rows.
///
/// Returns `Ok(None)` when there is nothing to output (semi/anti/mark
/// joins, empty right batch, or no matching rows).
fn process_single_left_row_join(
    &mut self,
    left_data: &JoinLeftData,
    right_batch: &RecordBatch,
    l_index: usize,
) -> Result<Option<RecordBatch>> {
    let right_row_count = right_batch.num_rows();
    if right_row_count == 0 {
        return Ok(None);
    }
    // Per-right-row match mask for this single left row.
    let cur_right_bitmap = if let Some(filter) = &self.join_filter {
        apply_filter_to_row_join_batch(
            left_data.batch(),
            l_index,
            right_batch,
            filter,
        )?
    } else {
        BooleanArray::from(vec![true; right_row_count])
    };
    // Record matches in the shared left bitmap and the right-batch bitmap.
    self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
    // Semi/anti/mark joins emit from the bitmaps later, not here.
    if matches!(
        self.join_type,
        JoinType::LeftAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::RightAnti
            | JoinType::RightMark
            | JoinType::RightSemi
    ) {
        return Ok(None);
    }
    if cur_right_bitmap.true_count() == 0 {
        Ok(None)
    } else {
        let join_batch = build_row_join_batch(
            &self.output_schema,
            left_data.batch(),
            l_index,
            right_batch,
            Some(cur_right_bitmap),
            &self.column_indices,
            JoinSide::Left,
        )?;
        Ok(join_batch)
    }
}
/// Emits one `batch_size` slice of left-unmatched rows per call. Returns
/// `Ok(false)` when there is nothing (more) to emit: the join type never
/// produces left rows, another partition is responsible for the final pass,
/// or the build side has been fully scanned.
fn process_left_unmatched(&mut self) -> Result<bool> {
    let left_data = self.get_left_data()?;
    let left_batch = left_data.batch();
    let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
    // On the first slice (emit idx 0), decrement the shared probe counter;
    // only the last partition to finish performs the unmatched pass.
    let handled_by_other_partition =
        self.left_emit_idx == 0 && !left_data.report_probe_completed();
    let finished = self.left_emit_idx >= left_batch.num_rows();
    if join_type_no_produce_left || handled_by_other_partition || finished {
        return Ok(false);
    }
    let start_idx = self.left_emit_idx;
    let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
    if let Some(batch) =
        self.process_left_unmatched_range(left_data, start_idx, end_idx)?
    {
        self.output_buffer.push_batch(batch)?;
    }
    self.left_emit_idx = end_idx;
    Ok(true)
}
/// Builds the output rows for the unmatched build-side rows in
/// `[start_idx, end_idx)`, according to the shared left matched-bitmap.
///
/// Fix: removed a leftover debugging `assert!` (message `"DBG: ..."`) from
/// the bit-copy loop — it was always true, since the local builder is
/// created and `append_n`-sized to exactly `end_idx - start_idx` bits.
fn process_left_unmatched_range(
    &self,
    left_data: &JoinLeftData,
    start_idx: usize,
    end_idx: usize,
) -> Result<Option<RecordBatch>> {
    if start_idx == end_idx {
        return Ok(None);
    }
    let left_batch = left_data.batch();
    let len = end_idx - start_idx;
    let left_batch_sliced = left_batch.slice(start_idx, len);
    // Copy the relevant window of the shared bitmap into a local array
    // sized exactly for this slice.
    let mut bitmap_sliced = BooleanBufferBuilder::new(len);
    bitmap_sliced.append_n(len, false);
    let bitmap = left_data.bitmap().lock();
    for i in start_idx..end_idx {
        bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
    }
    let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
    let right_schema = self.right_data.schema();
    build_unmatched_batch(
        &self.output_schema,
        &left_batch_sliced,
        bitmap_sliced,
        &right_schema,
        &self.column_indices,
        self.join_type,
        JoinSide::Left,
    )
}
/// Builds the output rows for the unmatched rows of the current right
/// batch, consuming both the batch and its matched-bitmap.
///
/// Fix: dropped the trailing `self.current_right_batch_matched = None;` —
/// `std::mem::take` already leaves `None` behind, so the reset was
/// redundant.
fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
    let right_batch_bitmap: BooleanArray =
        std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
            internal_datafusion_err!("right bitmap should be available")
        })?;
    // Consume the right batch; the caller asserts it is gone afterwards.
    let right_batch = self.current_right_batch.take();
    let cur_right_batch = unwrap_or_internal_err!(right_batch);
    let left_data = self.get_left_data()?;
    let left_schema = left_data.batch().schema();
    build_unmatched_batch(
        &self.output_schema,
        &cur_right_batch,
        right_batch_bitmap,
        &left_schema,
        &self.column_indices,
        self.join_type,
        JoinSide::Right,
    )
}
/// Borrows the buffered build side, erroring if it has not been loaded yet.
fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
    match self.buffered_left_data.as_ref() {
        Some(left_data) => Ok(left_data),
        None => Err(internal_datafusion_err!("LeftData should be available")),
    }
}
/// Pops a completed batch from the output coalescer, if one is ready, and
/// wraps it as a `Poll` for the caller to return; also feeds the emitted
/// row count into the selectivity numerator.
fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
    if self.output_buffer.has_completed_batch()
        && let Some(batch) = self.output_buffer.next_completed_batch()
    {
        let output_rows = batch.num_rows();
        self.metrics.selectivity.add_part(output_rows);
        return Some(Poll::Ready(Some(Ok(batch))));
    }
    None
}
/// Records that build-side row `l_index` matched at least one probe row, and
/// folds the per-pair match mask into the bitmap tracking matched rows of the
/// current right (probe) batch when the join type requires it.
///
/// `r_matched_bitmap` has one bit per row of the current right batch: true
/// where the (left row, right row) pair passed the filter.
fn update_matched_bitmap(
    &mut self,
    l_index: usize,
    r_matched_bitmap: &BooleanArray,
) -> Result<()> {
    let left_data = self.get_left_data()?;
    let joined_len = r_matched_bitmap.true_count();
    // Left-side bookkeeping: only joins that later emit unmatched build rows
    // (see `need_produce_result_in_final`) use the shared bitmap, and only
    // when this left row actually matched something.
    if need_produce_result_in_final(self.join_type) && (joined_len > 0) {
        let mut bitmap = left_data.bitmap().lock();
        bitmap.set_bit(l_index, true);
    }
    // Right-side bookkeeping: OR the new matches into the current right
    // batch's bitmap. The bitmap is moved out with `take` (to consume it by
    // value for `into_parts`) and then put back updated.
    if self.should_track_unmatched_right {
        debug_assert!(self.current_right_batch_matched.is_some());
        let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
            .ok_or_else(|| {
                internal_datafusion_err!("right batch's bitmap should be present")
            })?;
        let (buf, nulls) = right_bitmap.into_parts();
        debug_assert!(nulls.is_none());
        let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
        self.current_right_batch_matched =
            Some(BooleanArray::new(updated_right_bitmap, None));
    }
    Ok(())
}
}
/// Evaluates the join filter for the cross product of a single left row
/// (`left_batch` at `l_index`) with every row of `right_batch`, returning a
/// null-free boolean mask with one bit per right row (true = pair passes).
fn apply_filter_to_row_join_batch(
    left_batch: &RecordBatch,
    l_index: usize,
    right_batch: &RecordBatch,
    filter: &JoinFilter,
) -> Result<BooleanArray> {
    debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
    // Assemble the filter's intermediate batch: the single left row
    // replicated across the right batch plus the referenced right columns.
    // A filter referencing no columns (constant predicate) still needs a
    // batch with the correct row count to evaluate against.
    let intermediate_batch = if filter.schema.fields().is_empty() {
        create_record_batch_with_empty_schema(
            Arc::new((*filter.schema).clone()),
            right_batch.num_rows(),
        )?
    } else {
        build_row_join_batch(
            &filter.schema,
            left_batch,
            l_index,
            right_batch,
            None,
            &filter.column_indices,
            JoinSide::Left,
        )?
        .ok_or_else(|| internal_datafusion_err!("This function assume input batch is not empty, so the intermediate batch can't be empty too"))?
    };
    let filter_result = filter
        .expression()
        .evaluate(&intermediate_batch)?
        .into_array(intermediate_batch.num_rows())?;
    let filter_arr = as_boolean_array(&filter_result)?;
    // NULL filter results count as "no match": fold validity into the values.
    let bitmap_combined = boolean_mask_from_filter(filter_arr);
    Ok(bitmap_combined)
}
/// Converts a possibly-nullable filter result into a null-free mask by
/// AND-ing the validity buffer into the values, so NULL evaluates to `false`.
#[inline]
fn boolean_mask_from_filter(filter_arr: &BooleanArray) -> BooleanArray {
    let (values, nulls) = filter_arr.clone().into_parts();
    if let Some(null_buf) = nulls {
        BooleanArray::new(null_buf.inner() & &values, None)
    } else {
        BooleanArray::new(values, None)
    }
}
/// Joins one build-side row (`build_side_batch` at `build_side_index`) with
/// every (optionally pre-filtered) row of `probe_side_batch`, producing a
/// batch laid out according to `output_schema` / `col_indices`.
///
/// Build-side columns are broadcast: the single value at `build_side_index`
/// is replicated to the probe batch's row count. Probe-side columns are
/// passed through (filtered if `probe_side_filter` is given).
///
/// Returns `Ok(None)` when the filter leaves no probe rows.
fn build_row_join_batch(
    output_schema: &Schema,
    build_side_batch: &RecordBatch,
    build_side_index: usize,
    probe_side_batch: &RecordBatch,
    probe_side_filter: Option<BooleanArray>,
    col_indices: &[ColumnIndex],
    build_side: JoinSide,
) -> Result<Option<RecordBatch>> {
    debug_assert!(build_side != JoinSide::None);
    let filtered_probe_batch = if let Some(filter) = probe_side_filter {
        &filter_record_batch(probe_side_batch, &filter)?
    } else {
        probe_side_batch
    };
    if filtered_probe_batch.num_rows() == 0 {
        return Ok(None);
    }
    // Columnless output (e.g. COUNT(*) over the join) only needs a row count.
    if output_schema.fields.is_empty() {
        return Ok(Some(create_record_batch_with_empty_schema(
            Arc::new(output_schema.clone()),
            filtered_probe_batch.num_rows(),
        )?));
    }
    let mut columns: Vec<Arc<dyn Array>> =
        Vec::with_capacity(output_schema.fields().len());
    for column_index in col_indices {
        let array = if column_index.side == build_side {
            let original_left_array = build_side_batch.column(column_index.index);
            match original_left_array.data_type() {
                // Special case: broadcast List/LargeList of Utf8View via
                // `take` with repeated indices rather than the ScalarValue
                // round-trip below — presumably the ScalarValue conversion
                // does not handle Utf8View inside lists; TODO confirm.
                DataType::List(field) | DataType::LargeList(field)
                    if field.data_type() == &DataType::Utf8View =>
                {
                    let indices_iter = std::iter::repeat_n(
                        build_side_index as u64,
                        filtered_probe_batch.num_rows(),
                    );
                    let indices_array = UInt64Array::from_iter_values(indices_iter);
                    take(original_left_array.as_ref(), &indices_array, None)?
                }
                // General case: extract the single value as a ScalarValue and
                // materialize it `num_rows` times.
                _ => {
                    let scalar_value = ScalarValue::try_from_array(
                        original_left_array.as_ref(),
                        build_side_index,
                    )?;
                    scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
                }
            }
        } else {
            // Probe-side columns are reused as-is (already filtered above).
            Arc::clone(filtered_probe_batch.column(column_index.index))
        };
        columns.push(array);
    }
    Ok(Some(RecordBatch::try_new(
        Arc::new(output_schema.clone()),
        columns,
    )?))
}
/// Fast path for joins whose output schema has no columns (e.g. `COUNT(*)`
/// over a join): only the row count matters, so build a columnless batch of
/// the right size directly from the match bitmap.
///
/// Returns `Ok(None)` when the output schema is non-empty, in which case the
/// regular column-building path in `build_unmatched_batch` must run.
fn build_unmatched_batch_empty_schema(
    output_schema: &SchemaRef,
    batch_bitmap: &BooleanArray,
    join_type: JoinType,
) -> Result<Option<RecordBatch>> {
    // Check the schema first so the bitmap is scanned only when the count is
    // actually used (previously `false_count`/`true_count` ran on every call,
    // with the result discarded for non-empty schemas).
    if !output_schema.fields().is_empty() {
        return Ok(None);
    }
    let result_size = match join_type {
        // Outer and anti joins emit one row per *unmatched* input row.
        JoinType::Left
        | JoinType::Right
        | JoinType::Full
        | JoinType::LeftAnti
        | JoinType::RightAnti => batch_bitmap.false_count(),
        // Semi joins emit one row per *matched* input row.
        JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
        // Mark joins emit every input row (plus the mark column, absent here).
        JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
        // Inner joins never reach this helper; callers assert non-Inner.
        _ => unreachable!(),
    };
    Ok(Some(create_record_batch_with_empty_schema(
        Arc::clone(output_schema),
        result_size,
    )?))
}
/// Creates a batch that has no columns but a definite `row_count`, using
/// `RecordBatchOptions` (the only way arrow allows a columnless non-zero-row
/// batch).
fn create_record_batch_with_empty_schema(
    schema: SchemaRef,
    row_count: usize,
) -> Result<RecordBatch> {
    let options = RecordBatchOptions::new()
        .with_match_field_names(true)
        .with_row_count(Some(row_count));
    match RecordBatch::try_new_with_options(schema, vec![], &options) {
        Ok(batch) => Ok(batch),
        Err(e) => Err(internal_datafusion_err!(
            "Failed to create empty record batch: {}",
            e
        )),
    }
}
fn build_unmatched_batch(
output_schema: &SchemaRef,
batch: &RecordBatch,
batch_bitmap: BooleanArray,
another_side_schema: &SchemaRef,
col_indices: &[ColumnIndex],
join_type: JoinType,
batch_side: JoinSide,
) -> Result<Option<RecordBatch>> {
debug_assert_ne!(join_type, JoinType::Inner);
debug_assert_ne!(batch_side, JoinSide::None);
if let Some(batch) =
build_unmatched_batch_empty_schema(output_schema, &batch_bitmap, join_type)?
{
return Ok(Some(batch));
}
match join_type {
JoinType::Full | JoinType::Right | JoinType::Left => {
if join_type == JoinType::Right {
debug_assert_eq!(batch_side, JoinSide::Right);
}
if join_type == JoinType::Left {
debug_assert_eq!(batch_side, JoinSide::Left);
}
let flipped_bitmap = not(&batch_bitmap)?;
let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
.fields()
.iter()
.map(|field| new_null_array(field.data_type(), 1))
.collect();
let nullable_left_schema = Arc::new(Schema::new(
another_side_schema
.fields()
.iter()
.map(|field| (**field).clone().with_nullable(true))
.collect::<Vec<_>>(),
));
let left_null_batch = if nullable_left_schema.fields.is_empty() {
create_record_batch_with_empty_schema(nullable_left_schema, 0)?
} else {
RecordBatch::try_new(nullable_left_schema, left_null_columns)?
};
debug_assert_ne!(batch_side, JoinSide::None);
let opposite_side = batch_side.negate();
build_row_join_batch(
output_schema,
&left_null_batch,
0,
batch,
Some(flipped_bitmap),
col_indices,
opposite_side,
)
}
JoinType::RightSemi
| JoinType::RightAnti
| JoinType::LeftSemi
| JoinType::LeftAnti => {
if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
debug_assert_eq!(batch_side, JoinSide::Right);
}
if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
debug_assert_eq!(batch_side, JoinSide::Left);
}
let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi)
{
batch_bitmap.clone()
} else {
not(&batch_bitmap)?
};
if bitmap.true_count() == 0 {
return Ok(None);
}
let mut columns: Vec<Arc<dyn Array>> =
Vec::with_capacity(output_schema.fields().len());
for column_index in col_indices {
debug_assert!(column_index.side == batch_side);
let col = batch.column(column_index.index);
let filtered_col = filter(col, &bitmap)?;
columns.push(filtered_col);
}
Ok(Some(RecordBatch::try_new(
Arc::clone(output_schema),
columns,
)?))
}
JoinType::RightMark | JoinType::LeftMark => {
if join_type == JoinType::RightMark {
debug_assert_eq!(batch_side, JoinSide::Right);
}
if join_type == JoinType::LeftMark {
debug_assert_eq!(batch_side, JoinSide::Left);
}
let mut columns: Vec<Arc<dyn Array>> =
Vec::with_capacity(output_schema.fields().len());
let mut right_batch_bitmap_opt = Some(batch_bitmap);
for column_index in col_indices {
if column_index.side == batch_side {
let col = batch.column(column_index.index);
columns.push(Arc::clone(col));
} else if column_index.side == JoinSide::None {
let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
match right_batch_bitmap {
Some(right_batch_bitmap) => {
columns.push(Arc::new(right_batch_bitmap))
}
None => unreachable!("Should only be one mark column"),
}
} else {
return internal_err!(
"Not possible to have this join side for RightMark join"
);
}
}
Ok(Some(RecordBatch::try_new(
Arc::clone(output_schema),
columns,
)?))
}
_ => internal_err!(
"If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"
),
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::test::{TestMemoryExec, assert_join_metrics};
use crate::{
common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field};
use datafusion_common::test_util::batches_to_sort_string;
use datafusion_common::{ScalarValue, assert_contains};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use insta::allow_duplicates;
use insta::assert_snapshot;
use rstest::rstest;
/// Builds an in-memory test table from three named i32 columns.
///
/// When `batch_size` is `Some(n)`, the single batch is split into slices of
/// at most `n` rows; otherwise one batch is used. `sorted_column_names`
/// optionally attaches descending-sort ordering information to the source.
fn build_table(
    a: (&str, &Vec<i32>),
    b: (&str, &Vec<i32>),
    c: (&str, &Vec<i32>),
    batch_size: Option<usize>,
    sorted_column_names: Vec<&str>,
) -> Arc<dyn ExecutionPlan> {
    let batch = build_table_i32(a, b, c);
    let schema = batch.schema();
    let batches = if let Some(batch_size) = batch_size {
        // Split into ceil(rows / batch_size) slices; the last slice may be
        // shorter than batch_size.
        let num_batches = batch.num_rows().div_ceil(batch_size);
        (0..num_batches)
            .map(|i| {
                let start = i * batch_size;
                let remaining_rows = batch.num_rows() - start;
                batch.slice(start, batch_size.min(remaining_rows))
            })
            .collect::<Vec<_>>()
    } else {
        vec![batch]
    };
    let mut sort_info = vec![];
    for name in sorted_column_names {
        let index = schema.index_of(name).unwrap();
        // SortOptions::new(false, false): descending = false? No — the first
        // flag is `descending`; here (false, false) means ascending,
        // nulls-last per arrow's SortOptions::new(descending, nulls_first).
        let sort_expr = PhysicalSortExpr::new(
            Arc::new(Column::new(name, index)),
            SortOptions::new(false, false),
        );
        sort_info.push(sort_expr);
    }
    let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
    if let Some(ordering) = LexOrdering::new(sort_info) {
        source = source.try_with_sort_information(vec![ordering]).unwrap();
    }
    let source = Arc::new(source);
    Arc::new(TestMemoryExec::update_cache(&source))
}
/// Standard three-row build-side table; `b1` contains the duplicated value 8
/// so filter-based tests produce unmatched left rows.
fn build_left_table() -> Arc<dyn ExecutionPlan> {
    let (a1, b1, c1) = (vec![5, 9, 11], vec![5, 8, 8], vec![50, 90, 110]);
    build_table(("a1", &a1), ("b1", &b1), ("c1", &c1), None, vec![])
}
/// Standard three-row probe-side table; `b2` contains the duplicated value 10
/// so filter-based tests produce unmatched right rows.
fn build_right_table() -> Arc<dyn ExecutionPlan> {
    let (a2, b2, c2) = (vec![12, 2, 10], vec![10, 2, 10], vec![40, 80, 100]);
    build_table(("a2", &a2), ("b2", &b2), ("c2", &c2), None, vec![])
}
/// Builds the residual join filter used by most tests:
/// `left.b1 != 8 AND right.b2 != 10` (both inputs are column index 1 of
/// their side, mapped to intermediate-schema columns 0 and 1).
fn prepare_join_filter() -> JoinFilter {
    // Intermediate column 0 <- left column 1 (b1), column 1 <- right column 1 (b2).
    let column_indices = vec![
        ColumnIndex {
            index: 1,
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: 1,
            side: JoinSide::Right,
        },
    ];
    let intermediate_schema = Schema::new(vec![
        Field::new("x", DataType::Int32, true),
        Field::new("x", DataType::Int32, true),
    ]);
    // b1 != 8
    let left_filter = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("x", 0)),
        Operator::NotEq,
        Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
    )) as Arc<dyn PhysicalExpr>;
    // b2 != 10
    let right_filter = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("x", 1)),
        Operator::NotEq,
        Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
    )) as Arc<dyn PhysicalExpr>;
    let filter_expression =
        Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
            as Arc<dyn PhysicalExpr>;
    JoinFilter::new(
        filter_expression,
        column_indices,
        Arc::new(intermediate_schema),
    )
}
/// Runs a NestedLoopJoinExec with the probe side repartitioned into 4
/// round-robin partitions, collects all partitions' output, and returns the
/// output column names, the non-empty batches, and the join metrics.
///
/// Also asserts every produced batch respects the session's batch size.
pub(crate) async fn multi_partitioned_join_collect(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    join_type: &JoinType,
    join_filter: Option<JoinFilter>,
    context: Arc<TaskContext>,
) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
    let partition_count = 4;
    let right = Arc::new(RepartitionExec::try_new(
        right,
        Partitioning::RoundRobinBatch(partition_count),
    )?) as Arc<dyn ExecutionPlan>;
    let nested_loop_join =
        NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
    let columns = columns(&nested_loop_join.schema());
    let mut batches = vec![];
    for i in 0..partition_count {
        let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
        let more_batches = common::collect(stream).await?;
        batches.extend(
            more_batches
                .into_iter()
                // Every emitted batch must honor the configured batch size.
                .inspect(|b| {
                    assert!(b.num_rows() <= context.session_config().batch_size())
                })
                .filter(|b| b.num_rows() > 0)
                .collect::<Vec<_>>(),
        );
    }
    let metrics = nested_loop_join.metrics().unwrap();
    Ok((columns, batches, metrics))
}
/// Creates a default task context whose session config uses the given batch
/// size.
fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
    let ctx = TaskContext::default();
    let updated_config = ctx.session_config().clone().with_batch_size(batch_size);
    Arc::new(ctx.with_session_config(updated_config))
}
/// Inner join with the residual filter (b1 != 8 AND b2 != 10): only the pair
/// (5, 5, 50) x (2, 2, 80) survives. (A leftover `dbg!` call was removed.)
#[rstest]
#[tokio::test]
async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::Inner,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+----+----+----+----+
    | a1 | b1 | c1 | a2 | b2 | c2 |
    +----+----+----+----+----+----+
    | 5  | 5  | 50 | 2  | 2  | 80 |
    +----+----+----+----+----+----+
    "));
    assert_join_metrics!(metrics, 1);
    Ok(())
}
/// Left outer join with the residual filter: left rows with b1 = 8 (9 and 11)
/// match nothing and are emitted with NULL right-side columns.
#[rstest]
#[tokio::test]
async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::Left,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+-----+----+----+----+
    | a1 | b1 | c1  | a2 | b2 | c2 |
    +----+----+-----+----+----+----+
    | 11 | 8  | 110 |    |    |    |
    | 5  | 5  | 50  | 2  | 2  | 80 |
    | 9  | 8  | 90  |    |    |    |
    +----+----+-----+----+----+----+
    "));
    assert_join_metrics!(metrics, 3);
    Ok(())
}
/// Right outer join with the residual filter: right rows with b2 = 10 match
/// nothing and are emitted with NULL left-side columns.
#[rstest]
#[tokio::test]
async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::Right,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+----+----+----+-----+
    | a1 | b1 | c1 | a2 | b2 | c2  |
    +----+----+----+----+----+-----+
    |    |    |    | 10 | 10 | 100 |
    |    |    |    | 12 | 10 | 40  |
    | 5  | 5  | 50 | 2  | 2  | 80  |
    +----+----+----+----+----+-----+
    "));
    assert_join_metrics!(metrics, 3);
    Ok(())
}
/// Full outer join with the residual filter: unmatched rows from BOTH sides
/// appear with NULLs for the opposite side.
#[rstest]
#[tokio::test]
async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::Full,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+-----+----+----+-----+
    | a1 | b1 | c1  | a2 | b2 | c2  |
    +----+----+-----+----+----+-----+
    |    |    |     | 10 | 10 | 100 |
    |    |    |     | 12 | 10 | 40  |
    | 11 | 8  | 110 |    |    |     |
    | 5  | 5  | 50  | 2  | 2  | 80  |
    | 9  | 8  | 90  |    |    |     |
    +----+----+-----+----+----+-----+
    "));
    assert_join_metrics!(metrics, 5);
    Ok(())
}
/// Left semi join with the residual filter: only left rows that matched at
/// least one right row are emitted, with left columns only.
#[rstest]
#[tokio::test]
async fn join_left_semi_with_filter(
    #[values(1, 2, 16)] batch_size: usize,
) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::LeftSemi,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+----+
    | a1 | b1 | c1 |
    +----+----+----+
    | 5  | 5  | 50 |
    +----+----+----+
    "));
    assert_join_metrics!(metrics, 1);
    Ok(())
}
/// Left anti join with the residual filter: only left rows with NO match are
/// emitted, with left columns only.
#[rstest]
#[tokio::test]
async fn join_left_anti_with_filter(
    #[values(1, 2, 16)] batch_size: usize,
) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::LeftAnti,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+-----+
    | a1 | b1 | c1  |
    +----+----+-----+
    | 11 | 8  | 110 |
    | 9  | 8  | 90  |
    +----+----+-----+
    "));
    assert_join_metrics!(metrics, 2);
    Ok(())
}
/// A projected (columns 1 and 2) left join must report statistics with one
/// entry per projected output column.
#[tokio::test]
async fn join_has_correct_stats() -> Result<()> {
    let left = build_left_table();
    let right = build_right_table();
    let nested_loop_join = NestedLoopJoinExec::try_new(
        left,
        right,
        None,
        &JoinType::Left,
        Some(vec![1, 2]),
    )?;
    let stats = nested_loop_join.partition_statistics(None)?;
    // Column statistics must line up with the (projected) output schema.
    assert_eq!(
        nested_loop_join.schema().fields().len(),
        stats.column_statistics.len(),
    );
    assert_eq!(2, stats.column_statistics.len());
    Ok(())
}
/// Right semi join with the residual filter: only right rows that matched at
/// least one left row are emitted, with right columns only.
#[rstest]
#[tokio::test]
async fn join_right_semi_with_filter(
    #[values(1, 2, 16)] batch_size: usize,
) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::RightSemi,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a2", "b2", "c2"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+----+
    | a2 | b2 | c2 |
    +----+----+----+
    | 2  | 2  | 80 |
    +----+----+----+
    "));
    assert_join_metrics!(metrics, 1);
    Ok(())
}
/// Right anti join with the residual filter: only right rows with NO match
/// are emitted, with right columns only.
#[rstest]
#[tokio::test]
async fn join_right_anti_with_filter(
    #[values(1, 2, 16)] batch_size: usize,
) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::RightAnti,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a2", "b2", "c2"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+-----+
    | a2 | b2 | c2  |
    +----+----+-----+
    | 10 | 10 | 100 |
    | 12 | 10 | 40  |
    +----+----+-----+
    "));
    assert_join_metrics!(metrics, 2);
    Ok(())
}
/// Left mark join with the residual filter: every left row is emitted once,
/// with a boolean "mark" column indicating whether it matched.
#[rstest]
#[tokio::test]
async fn join_left_mark_with_filter(
    #[values(1, 2, 16)] batch_size: usize,
) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::LeftMark,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+-----+-------+
    | a1 | b1 | c1  | mark  |
    +----+----+-----+-------+
    | 11 | 8  | 110 | false |
    | 5  | 5  | 50  | true  |
    | 9  | 8  | 90  | false |
    +----+----+-----+-------+
    "));
    assert_join_metrics!(metrics, 3);
    Ok(())
}
/// Right mark join with the residual filter: every right row is emitted once,
/// with a boolean "mark" column indicating whether it matched.
#[rstest]
#[tokio::test]
async fn join_right_mark_with_filter(
    #[values(1, 2, 16)] batch_size: usize,
) -> Result<()> {
    let task_ctx = new_task_ctx(batch_size);
    let left = build_left_table();
    let right = build_right_table();
    let filter = prepare_join_filter();
    let (columns, batches, metrics) = multi_partitioned_join_collect(
        left,
        right,
        &JoinType::RightMark,
        Some(filter),
        task_ctx,
    )
    .await?;
    assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
    allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+-----+-------+
    | a2 | b2 | c2  | mark  |
    +----+----+-----+-------+
    | 10 | 10 | 100 | false |
    | 12 | 10 | 40  | false |
    | 2  | 2  | 80  | true  |
    +----+----+-----+-------+
    "));
    assert_join_metrics!(metrics, 3);
    Ok(())
}
/// Every join type must fail gracefully (not OOM or hang) when the runtime's
/// memory limit is too small to buffer the build side; the error message must
/// name the NestedLoopJoinLoad consumer.
#[tokio::test]
async fn test_overallocation() -> Result<()> {
    let left = build_table(
        ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        None,
        Vec::new(),
    );
    let right = build_table(
        ("a2", &vec![10, 11]),
        ("b2", &vec![12, 13]),
        ("c2", &vec![14, 15]),
        None,
        Vec::new(),
    );
    let filter = prepare_join_filter();
    let join_types = vec![
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::Full,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
        JoinType::LeftMark,
        JoinType::RightSemi,
        JoinType::RightAnti,
        JoinType::RightMark,
    ];
    for join_type in join_types {
        // 100-byte limit: far too small to hold the collected left side.
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(100, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);
        let err = multi_partitioned_join_collect(
            Arc::clone(&left),
            Arc::clone(&right),
            &join_type,
            Some(filter.clone()),
            task_ctx,
        )
        .await
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Resources exhausted: Additional allocation failed for NestedLoopJoinLoad[0] with top memory consumers (across reservations) as:\n  NestedLoopJoinLoad[0]"
        );
    }
    Ok(())
}
/// Returns the schema's field names as owned strings, in schema order.
fn columns(schema: &Schema) -> Vec<String> {
    let mut names = Vec::with_capacity(schema.fields().len());
    for field in schema.fields() {
        names.push(field.name().clone());
    }
    names
}
}