datafusion_physical_plan/joins/nested_loop_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates).
19
20use std::any::Any;
21use std::fmt::Formatter;
22use std::ops::{BitOr, ControlFlow};
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::task::Poll;
26
27use super::utils::{
28    asymmetric_join_output_partitioning, need_produce_result_in_final,
29    reorder_output_after_swap, swap_join_projection,
30};
31use crate::common::can_project;
32use crate::execution_plan::{boundedness_from_children, EmissionType};
33use crate::joins::utils::{
34    build_join_schema, check_join_is_valid, estimate_join_statistics,
35    need_produce_right_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
36    OnceAsync, OnceFut,
37};
38use crate::joins::SharedBitmapBuilder;
39use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricsSet};
40use crate::projection::{
41    try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
42    ProjectionExec,
43};
44use crate::{
45    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
46    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
47};
48
49use arrow::array::{
50    new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions,
51};
52use arrow::buffer::BooleanBuffer;
53use arrow::compute::{concat_batches, filter, filter_record_batch, not, BatchCoalescer};
54use arrow::datatypes::{Schema, SchemaRef};
55use arrow::record_batch::RecordBatch;
56use datafusion_common::cast::as_boolean_array;
57use datafusion_common::{
58    arrow_err, internal_datafusion_err, internal_err, project_schema,
59    unwrap_or_internal_err, DataFusionError, JoinSide, Result, ScalarValue, Statistics,
60};
61use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
62use datafusion_execution::TaskContext;
63use datafusion_expr::JoinType;
64use datafusion_physical_expr::equivalence::{
65    join_equivalence_properties, ProjectionMapping,
66};
67
68use futures::{Stream, StreamExt, TryStreamExt};
69use log::debug;
70use parking_lot::Mutex;
71
72#[allow(rustdoc::private_intra_doc_links)]
73/// NestedLoopJoinExec is a build-probe join operator designed for joins that
74/// do not have equijoin keys in their `ON` clause.
75///
76/// # Execution Flow
77///
78/// ```text
79///                                                Incoming right batch
80///                Left Side Buffered Batches
81///                       ┌───────────┐              ┌───────────────┐
82///                       │ ┌───────┐ │              │               │
83///                       │ │       │ │              │               │
84///  Current Left Row ───▶│ ├───────├─┤──────────┐   │               │
85///                       │ │       │ │          │   └───────────────┘
86///                       │ │       │ │          │           │
87///                       │ │       │ │          │           │
88///                       │ └───────┘ │          │           │
89///                       │ ┌───────┐ │          │           │
90///                       │ │       │ │          │     ┌─────┘
91///                       │ │       │ │          │     │
92///                       │ │       │ │          │     │
93///                       │ │       │ │          │     │
94///                       │ │       │ │          │     │
95///                       │ └───────┘ │          ▼     ▼
96///                       │   ......  │  ┌──────────────────────┐
97///                       │           │  │X (Cartesian Product) │
98///                       │           │  └──────────┬───────────┘
99///                       └───────────┘             │
100///                                                 │
101///                                                 ▼
102///                                      ┌───────┬───────────────┐
103///                                      │       │               │
104///                                      │       │               │
105///                                      │       │               │
106///                                      └───────┴───────────────┘
107///                                        Intermediate Batch
108///                                  (For join predicate evaluation)
109/// ```
110///
111/// The execution follows a two-phase design:
112///
113/// ## 1. Buffering Left Input
/// - The operator eagerly buffers all left-side input batches into memory,
///   until a memory limit is reached.
///   Currently, an out-of-memory error is returned if the left-side input
///   batches cannot all fit into memory at once.
///   In the future, this case may be made to finish execution (see the
///   'Memory-limited Execution' section).
120/// - The rationale for buffering the left side is that scanning the right side
121///   can be expensive (e.g., decoding Parquet files), so buffering more left
122///   rows reduces the number of right-side scan passes required.
123///
124/// ## 2. Probing Right Input
125/// - Right-side input is streamed batch by batch.
126/// - For each right-side batch:
127///   - It evaluates the join filter against the full buffered left input.
128///     This results in a Cartesian product between the right batch and each
129///     left row -- with the join predicate/filter applied -- for each inner
130///     loop iteration.
131///   - Matched results are accumulated into an output buffer. (see more in
132///     `Output Buffering Strategy` section)
/// - This process continues until all right-side input is consumed, as
///   sketched below.
134///
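/// A rough sketch of the probe loop (pseudocode mirroring the notes in
/// `poll_next`, not the actual API):
///
/// ```text
/// for right_batch in right_stream {
///     for left_row in buffered_left_rows {
///         // evaluate the join filter over (left_row x right_batch) and
///         // append the passing pairs to the output buffer
///     }
/// }
/// ```
///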
135/// # Producing unmatched build-side data
/// - For special join types like left/full joins, it's required to also output
///   unmatched pairs. During execution, bitmaps are kept for both left and
///   right sides of the input; they are handled by dedicated states in
///   [`NestedLoopJoinStream`] (see the sketch after this list).
/// - The final output of the unmatched left-side rows is handled by a single
///   partition for simplicity, since it accounts for only a small portion of
///   the execution time (e.g., if the probe side has 10k rows, emitting the
///   unmatched build side accounts for roughly 1/10k of the total time).
143///
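/// A rough sketch of how a build-side match is recorded (illustrative only;
/// `left_data` is the buffered [`JoinLeftData`], `l_idx` is the index of the
/// current build row, and `row_matches` is the boolean filter result of
/// joining that row against the current right batch):
///
/// ```ignore
/// if row_matches.true_count() > 0 {
///     // Mark the build row as visited so it is skipped later, when the
///     // unmatched build-side rows are emitted.
///     left_data.bitmap().lock().set_bit(l_idx, true);
/// }
/// ```
///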
144/// # Output Buffering Strategy
145/// The operator uses an intermediate output buffer to accumulate results. Once
146/// the output threshold is reached (currently set to the same value as
147/// `batch_size` in the configuration), the results will be eagerly output.
148///
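/// A rough sketch of the pattern (illustrative only; `output_buffer` is the
/// [`BatchCoalescer`] created with the configured `batch_size` as its target
/// size, and `joined_batch` is one chunk of join output):
///
/// ```ignore
/// // Accumulate join output; the coalescer concatenates/slices rows into
/// // batches of the target size internally.
/// output_buffer.push_batch(joined_batch)?;
/// // Once a full batch is ready, it is popped off the coalescer and returned
/// // to the caller (see `maybe_flush_ready_batch`); any remainder is flushed
/// // at the end via `output_buffer.finish_buffered_batch()?`.
/// ```
///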
149/// # Extra Notes
150/// - The operator always considers the **left** side as the build (buffered) side.
151///   Therefore, the physical optimizer should assign the smaller input to the left.
/// - The design tries to keep the intermediate data size to approximately
///   one batch, for better cache locality and memory efficiency.
154///
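/// # Example
///
/// A minimal construction sketch (illustrative only; `left` and `right` are
/// assumed to be existing `Arc<dyn ExecutionPlan>` inputs and `filter` an
/// optional non-equi [`JoinFilter`], e.g. `l.a < r.b`):
///
/// ```ignore
/// use datafusion_expr::JoinType;
///
/// let join = NestedLoopJoinExec::try_new(
///     left,             // build (buffered) side, ideally the smaller input
///     right,            // probe (streamed) side
///     filter,           // Option<JoinFilter> holding the non-equi predicate
///     &JoinType::Inner,
///     None,             // no output projection
/// )?;
/// ```
///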
155/// # TODO: Memory-limited Execution
156/// If the memory budget is exceeded during left-side buffering, fallback
157/// strategies such as streaming left batches and re-scanning the right side
158/// may be implemented in the future.
159///
160/// Tracking issue: <https://github.com/apache/datafusion/issues/15760>
161///
162/// # Clone / Shared State
163/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
164/// loading of the left side with the processing in each output stream.
/// Therefore it cannot be [`Clone`].
166#[derive(Debug)]
167pub struct NestedLoopJoinExec {
168    /// left side
169    pub(crate) left: Arc<dyn ExecutionPlan>,
170    /// right side
171    pub(crate) right: Arc<dyn ExecutionPlan>,
172    /// Filters which are applied while finding matching rows
173    pub(crate) filter: Option<JoinFilter>,
174    /// How the join is performed
175    pub(crate) join_type: JoinType,
176    /// The schema once the join is applied
177    join_schema: SchemaRef,
178    /// Future that consumes left input and buffers it in memory
179    ///
180    /// This structure is *shared* across all output streams.
181    ///
    /// Each output stream waits on the `OnceAsync` for the build (left) side
    /// data to be fully collected, then buffers it for later joining.
184    build_side_data: OnceAsync<JoinLeftData>,
185    /// Information of index and left / right placement of columns
186    column_indices: Vec<ColumnIndex>,
187    /// Projection to apply to the output of the join
188    projection: Option<Vec<usize>>,
189
190    /// Execution metrics
191    metrics: ExecutionPlanMetricsSet,
192    /// Cache holding plan properties like equivalences, output partitioning etc.
193    cache: PlanProperties,
194}
195
196impl NestedLoopJoinExec {
197    /// Try to create a new [`NestedLoopJoinExec`]
198    pub fn try_new(
199        left: Arc<dyn ExecutionPlan>,
200        right: Arc<dyn ExecutionPlan>,
201        filter: Option<JoinFilter>,
202        join_type: &JoinType,
203        projection: Option<Vec<usize>>,
204    ) -> Result<Self> {
205        let left_schema = left.schema();
206        let right_schema = right.schema();
207        check_join_is_valid(&left_schema, &right_schema, &[])?;
208        let (join_schema, column_indices) =
209            build_join_schema(&left_schema, &right_schema, join_type);
210        let join_schema = Arc::new(join_schema);
211        let cache = Self::compute_properties(
212            &left,
213            &right,
214            Arc::clone(&join_schema),
215            *join_type,
216            projection.as_ref(),
217        )?;
218
219        Ok(NestedLoopJoinExec {
220            left,
221            right,
222            filter,
223            join_type: *join_type,
224            join_schema,
225            build_side_data: Default::default(),
226            column_indices,
227            projection,
228            metrics: Default::default(),
229            cache,
230        })
231    }
232
233    /// left side
234    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
235        &self.left
236    }
237
238    /// right side
239    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
240        &self.right
241    }
242
    /// Join filter applied while finding matching rows
244    pub fn filter(&self) -> Option<&JoinFilter> {
245        self.filter.as_ref()
246    }
247
248    /// How the join is performed
249    pub fn join_type(&self) -> &JoinType {
250        &self.join_type
251    }
252
    /// Optional projection applied to the join output
    pub fn projection(&self) -> Option<&Vec<usize>> {
254        self.projection.as_ref()
255    }
256
257    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
258    fn compute_properties(
259        left: &Arc<dyn ExecutionPlan>,
260        right: &Arc<dyn ExecutionPlan>,
261        schema: SchemaRef,
262        join_type: JoinType,
263        projection: Option<&Vec<usize>>,
264    ) -> Result<PlanProperties> {
265        // Calculate equivalence properties:
266        let mut eq_properties = join_equivalence_properties(
267            left.equivalence_properties().clone(),
268            right.equivalence_properties().clone(),
269            &join_type,
270            Arc::clone(&schema),
271            &Self::maintains_input_order(join_type),
272            None,
273            // No on columns in nested loop join
274            &[],
275        )?;
276
277        let mut output_partitioning =
278            asymmetric_join_output_partitioning(left, right, &join_type)?;
279
280        let emission_type = if left.boundedness().is_unbounded() {
281            EmissionType::Final
282        } else if right.pipeline_behavior() == EmissionType::Incremental {
283            match join_type {
284                // If we only need to generate matched rows from the probe side,
285                // we can emit rows incrementally.
286                JoinType::Inner
287                | JoinType::LeftSemi
288                | JoinType::RightSemi
289                | JoinType::Right
290                | JoinType::RightAnti
291                | JoinType::RightMark => EmissionType::Incremental,
292                // If we need to generate unmatched rows from the *build side*,
293                // we need to emit them at the end.
294                JoinType::Left
295                | JoinType::LeftAnti
296                | JoinType::LeftMark
297                | JoinType::Full => EmissionType::Both,
298            }
299        } else {
300            right.pipeline_behavior()
301        };
302
303        if let Some(projection) = projection {
304            // construct a map from the input expressions to the output expression of the Projection
305            let projection_mapping =
306                ProjectionMapping::from_indices(projection, &schema)?;
307            let out_schema = project_schema(&schema, Some(projection))?;
308            output_partitioning =
309                output_partitioning.project(&projection_mapping, &eq_properties);
310            eq_properties = eq_properties.project(&projection_mapping, out_schema);
311        }
312
313        Ok(PlanProperties::new(
314            eq_properties,
315            output_partitioning,
316            emission_type,
317            boundedness_from_children([left, right]),
318        ))
319    }
320
321    /// This join implementation does not preserve the input order of either side.
322    fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
323        vec![false, false]
324    }
325
    /// Returns `true` if this operator applies a projection to its output
    pub fn contains_projection(&self) -> bool {
327        self.projection.is_some()
328    }
329
    /// Returns a copy of this operator with the given output projection,
    /// composed with any existing projection
    pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
331        // check if the projection is valid
332        can_project(&self.schema(), projection.as_ref())?;
333        let projection = match projection {
334            Some(projection) => match &self.projection {
335                Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
336                None => Some(projection),
337            },
338            None => None,
339        };
340        Self::try_new(
341            Arc::clone(&self.left),
342            Arc::clone(&self.right),
343            self.filter.clone(),
344            &self.join_type,
345            projection,
346        )
347    }
348
349    /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left
350    /// and right inputs swapped.
351    ///
352    /// # Notes:
353    ///
354    /// This function should be called BEFORE inserting any repartitioning
355    /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`]
356    /// for more details.
357    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
358        let left = self.left();
359        let right = self.right();
360        let new_join = NestedLoopJoinExec::try_new(
361            Arc::clone(right),
362            Arc::clone(left),
363            self.filter().map(JoinFilter::swap),
364            &self.join_type().swap(),
365            swap_join_projection(
366                left.schema().fields().len(),
367                right.schema().fields().len(),
368                self.projection.as_ref(),
369                self.join_type(),
370            ),
371        )?;
372
        // For Semi/Anti joins, the swapped plan produces the same output schema,
        // so there is no need to wrap it in an additional projection
375        let plan: Arc<dyn ExecutionPlan> = if matches!(
376            self.join_type(),
377            JoinType::LeftSemi
378                | JoinType::RightSemi
379                | JoinType::LeftAnti
380                | JoinType::RightAnti
381        ) || self.projection.is_some()
382        {
383            Arc::new(new_join)
384        } else {
385            reorder_output_after_swap(
386                Arc::new(new_join),
387                &self.left().schema(),
388                &self.right().schema(),
389            )?
390        };
391
392        Ok(plan)
393    }
394}
395
396impl DisplayAs for NestedLoopJoinExec {
397    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
398        match t {
399            DisplayFormatType::Default | DisplayFormatType::Verbose => {
400                let display_filter = self.filter.as_ref().map_or_else(
401                    || "".to_string(),
402                    |f| format!(", filter={}", f.expression()),
403                );
404                let display_projections = if self.contains_projection() {
405                    format!(
406                        ", projection=[{}]",
407                        self.projection
408                            .as_ref()
409                            .unwrap()
410                            .iter()
411                            .map(|index| format!(
412                                "{}@{}",
413                                self.join_schema.fields().get(*index).unwrap().name(),
414                                index
415                            ))
416                            .collect::<Vec<_>>()
417                            .join(", ")
418                    )
419                } else {
420                    "".to_string()
421                };
422                write!(
423                    f,
424                    "NestedLoopJoinExec: join_type={:?}{}{}",
425                    self.join_type, display_filter, display_projections
426                )
427            }
428            DisplayFormatType::TreeRender => {
429                if *self.join_type() != JoinType::Inner {
430                    writeln!(f, "join_type={:?}", self.join_type)
431                } else {
432                    Ok(())
433                }
434            }
435        }
436    }
437}
438
439impl ExecutionPlan for NestedLoopJoinExec {
440    fn name(&self) -> &'static str {
441        "NestedLoopJoinExec"
442    }
443
444    fn as_any(&self) -> &dyn Any {
445        self
446    }
447
448    fn properties(&self) -> &PlanProperties {
449        &self.cache
450    }
451
452    fn required_input_distribution(&self) -> Vec<Distribution> {
453        vec![
454            Distribution::SinglePartition,
455            Distribution::UnspecifiedDistribution,
456        ]
457    }
458
459    fn maintains_input_order(&self) -> Vec<bool> {
460        Self::maintains_input_order(self.join_type)
461    }
462
463    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
464        vec![&self.left, &self.right]
465    }
466
467    fn with_new_children(
468        self: Arc<Self>,
469        children: Vec<Arc<dyn ExecutionPlan>>,
470    ) -> Result<Arc<dyn ExecutionPlan>> {
471        Ok(Arc::new(NestedLoopJoinExec::try_new(
472            Arc::clone(&children[0]),
473            Arc::clone(&children[1]),
474            self.filter.clone(),
475            &self.join_type,
476            self.projection.clone(),
477        )?))
478    }
479
480    fn execute(
481        &self,
482        partition: usize,
483        context: Arc<TaskContext>,
484    ) -> Result<SendableRecordBatchStream> {
485        if self.left.output_partitioning().partition_count() != 1 {
486            return internal_err!(
487                "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
488                 consider using CoalescePartitionsExec or the EnforceDistribution rule"
489            );
490        }
491
492        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
493
        // Initial memory reservation for loading the build (left) side
495        let load_reservation =
496            MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
497                .register(context.memory_pool());
498
499        let build_side_data = self.build_side_data.try_once(|| {
500            let stream = self.left.execute(0, Arc::clone(&context))?;
501
502            Ok(collect_left_input(
503                stream,
504                join_metrics.clone(),
505                load_reservation,
506                need_produce_result_in_final(self.join_type),
507                self.right().output_partitioning().partition_count(),
508            ))
509        })?;
510
511        let batch_size = context.session_config().batch_size();
512
513        let probe_side_data = self.right.execute(partition, context)?;
514
515        // update column indices to reflect the projection
516        let column_indices_after_projection = match &self.projection {
517            Some(projection) => projection
518                .iter()
519                .map(|i| self.column_indices[*i].clone())
520                .collect(),
521            None => self.column_indices.clone(),
522        };
523
524        Ok(Box::pin(NestedLoopJoinStream::new(
525            self.schema(),
526            self.filter.clone(),
527            self.join_type,
528            probe_side_data,
529            build_side_data,
530            column_indices_after_projection,
531            join_metrics,
532            batch_size,
533        )))
534    }
535
536    fn metrics(&self) -> Option<MetricsSet> {
537        Some(self.metrics.clone_inner())
538    }
539
540    fn statistics(&self) -> Result<Statistics> {
541        self.partition_statistics(None)
542    }
543
544    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
545        if partition.is_some() {
546            return Ok(Statistics::new_unknown(&self.schema()));
547        }
548        estimate_join_statistics(
549            self.left.partition_statistics(None)?,
550            self.right.partition_statistics(None)?,
551            vec![],
552            &self.join_type,
553            &self.join_schema,
554        )
555    }
556
    /// Tries to push `projection` down through this nested loop join. If possible, performs the
    /// pushdown and returns a new [`NestedLoopJoinExec`] as the top plan, with the projections
    /// pushed down into its children. Otherwise, returns `None`.
560    fn try_swapping_with_projection(
561        &self,
562        projection: &ProjectionExec,
563    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
564        // TODO: currently if there is projection in NestedLoopJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later.
565        if self.contains_projection() {
566            return Ok(None);
567        }
568
569        if let Some(JoinData {
570            projected_left_child,
571            projected_right_child,
572            join_filter,
573            ..
574        }) = try_pushdown_through_join(
575            projection,
576            self.left(),
577            self.right(),
578            &[],
579            self.schema(),
580            self.filter(),
581        )? {
582            Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
583                Arc::new(projected_left_child),
584                Arc::new(projected_right_child),
585                join_filter,
586                self.join_type(),
                // We returned early above if `self.projection` is `Some`
588                None,
589            )?)))
590        } else {
591            try_embed_projection(projection, self)
592        }
593    }
594}
595
596impl EmbeddedProjection for NestedLoopJoinExec {
597    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
598        self.with_projection(projection)
599    }
600}
601
602/// Left (build-side) data
603pub(crate) struct JoinLeftData {
604    /// Build-side data collected to single batch
605    batch: RecordBatch,
606    /// Shared bitmap builder for visited left indices
607    bitmap: SharedBitmapBuilder,
608    /// Counter of running probe-threads, potentially able to update `bitmap`
609    probe_threads_counter: AtomicUsize,
    /// Memory reservation tracking the batch and bitmap;
    /// cleared when `JoinLeftData` is dropped
613    #[expect(dead_code)]
614    reservation: MemoryReservation,
615}
616
617impl JoinLeftData {
618    pub(crate) fn new(
619        batch: RecordBatch,
620        bitmap: SharedBitmapBuilder,
621        probe_threads_counter: AtomicUsize,
622        reservation: MemoryReservation,
623    ) -> Self {
624        Self {
625            batch,
626            bitmap,
627            probe_threads_counter,
628            reservation,
629        }
630    }
631
632    pub(crate) fn batch(&self) -> &RecordBatch {
633        &self.batch
634    }
635
636    pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
637        &self.bitmap
638    }
639
    /// Decrements the counter of running probe threads, and returns `true`
    /// if the caller is the last running thread
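    ///
    /// A sketch of the intended usage (illustrative only):
    ///
    /// ```ignore
    /// if left_data.report_probe_completed() {
    ///     // This is the last probe partition to finish, so it is the one
    ///     // that emits the unmatched build-side rows (e.g. LEFT/FULL joins).
    /// }
    /// ```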
642    pub(crate) fn report_probe_completed(&self) -> bool {
643        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
644    }
645}
646
/// Asynchronously collects the input into a single batch and creates `JoinLeftData` from it
648async fn collect_left_input(
649    stream: SendableRecordBatchStream,
650    join_metrics: BuildProbeJoinMetrics,
651    reservation: MemoryReservation,
652    with_visited_left_side: bool,
653    probe_threads_count: usize,
654) -> Result<JoinLeftData> {
655    let schema = stream.schema();
656
657    // Load all batches and count the rows
658    let (batches, metrics, mut reservation) = stream
659        .try_fold(
660            (Vec::new(), join_metrics, reservation),
661            |(mut batches, metrics, mut reservation), batch| async {
662                let batch_size = batch.get_array_memory_size();
663                // Reserve memory for incoming batch
664                reservation.try_grow(batch_size)?;
665                // Update metrics
666                metrics.build_mem_used.add(batch_size);
667                metrics.build_input_batches.add(1);
668                metrics.build_input_rows.add(batch.num_rows());
669                // Push batch to output
670                batches.push(batch);
671                Ok((batches, metrics, reservation))
672            },
673        )
674        .await?;
675
676    let merged_batch = concat_batches(&schema, &batches)?;
677
678    // Reserve memory for visited_left_side bitmap if required by join type
679    let visited_left_side = if with_visited_left_side {
680        let n_rows = merged_batch.num_rows();
681        let buffer_size = n_rows.div_ceil(8);
682        reservation.try_grow(buffer_size)?;
683        metrics.build_mem_used.add(buffer_size);
684
685        let mut buffer = BooleanBufferBuilder::new(n_rows);
686        buffer.append_n(n_rows, false);
687        buffer
688    } else {
689        BooleanBufferBuilder::new(0)
690    };
691
692    Ok(JoinLeftData::new(
693        merged_batch,
694        Mutex::new(visited_left_side),
695        AtomicUsize::new(probe_threads_count),
696        reservation,
697    ))
698}
699
700/// States for join processing. See `poll_next()` comment for more details about
701/// state transitions.
702#[derive(Debug, Clone, Copy)]
enum NLJState {
    /// Collect the build (left) side input into memory
    BufferingLeft,
    /// Fetch the next probe (right) side batch
    FetchingRight,
    /// Join the current right batch against the buffered left rows
    ProbeRight,
    /// Emit right rows that did not match any left row (e.g. right/full joins)
    EmitRightUnmatched,
    /// Emit left rows that did not match any right row (e.g. left/full joins)
    EmitLeftUnmatched,
    /// All output produced; terminate the stream
    Done,
}

/// A [`RecordBatchStream`] producing the nested loop join output; see the
/// comments on [`NestedLoopJoinExec`] and on `poll_next` for the design.
pub(crate) struct NestedLoopJoinStream {
712    // ========================================================================
713    // PROPERTIES:
714    // Operator's properties that remain constant
715    //
716    // Note: The implementation uses the terms left/build-side table and
717    // right/probe-side table interchangeably. Treating the left side as the
718    // build side is a convention in DataFusion: the planner always tries to
719    // swap the smaller table to the left side.
720    // ========================================================================
721    /// Output schema
722    pub(crate) output_schema: Arc<Schema>,
723    /// join filter
724    pub(crate) join_filter: Option<JoinFilter>,
725    /// type of the join
726    pub(crate) join_type: JoinType,
727    /// the probe-side(right) table data of the nested loop join
728    pub(crate) right_data: SendableRecordBatchStream,
729    /// the build-side table data of the nested loop join
730    pub(crate) left_data: OnceFut<JoinLeftData>,
731    /// Projection to construct the output schema from the left and right tables.
732    /// Example:
733    /// - output_schema: ['a', 'c']
734    /// - left_schema: ['a', 'b']
735    /// - right_schema: ['c']
736    ///
737    /// The column indices would be [(left, 0), (right, 0)] -- taking the left
738    /// 0th column and right 0th column can construct the output schema.
739    ///
740    /// Note there are other columns ('b' in the example) still kept after
741    /// projection pushdown; this is because they might be used to evaluate
742    /// the join filter (e.g., `JOIN ON (b+c)>0`).
743    pub(crate) column_indices: Vec<ColumnIndex>,
744    /// Join execution metrics
745    pub(crate) join_metrics: BuildProbeJoinMetrics,
746
747    /// `batch_size` from configuration
748    batch_size: usize,
749
750    /// See comments in [`need_produce_right_in_final`] for more detail
751    should_track_unmatched_right: bool,
752
753    // ========================================================================
754    // STATE FLAGS/BUFFERS:
755    // Fields that hold intermediate data/flags during execution
756    // ========================================================================
757    /// State Tracking
758    state: NLJState,
759    /// Output buffer holds the join result to output. It will emit eagerly when
760    /// the threshold is reached.
761    output_buffer: Box<BatchCoalescer>,
762    /// See comments in [`NLJState::Done`] for its purpose
763    handled_empty_output: bool,
764
765    // Buffer(left) side
766    // -----------------
767    /// The current buffered left data to join
768    buffered_left_data: Option<Arc<JoinLeftData>>,
769    /// Index into the left buffered batch. Used in `ProbeRight` state
770    left_probe_idx: usize,
771    /// Index into the left buffered batch. Used in `EmitLeftUnmatched` state
772    left_emit_idx: usize,
    /// Whether to go back to the `BufferingLeft` state again after the
    /// `EmitLeftUnmatched` state is over
775    left_exhausted: bool,
776    /// If we can buffer all left data in one pass
777    /// TODO(now): this is for the (unimplemented) memory-limited execution
778    #[allow(dead_code)]
779    left_buffered_in_one_pass: bool,
780
781    // Probe(right) side
782    // -----------------
783    /// The current probe batch to process
784    current_right_batch: Option<RecordBatch>,
    /// For join types that need to emit unmatched right-side rows, keeps track
    /// of which rows in `current_right_batch` have matched.
    /// Constructed when fetching each new incoming right batch in the
    /// `FetchingRight` state.
787    current_right_batch_matched: Option<BooleanArray>,
788}
789
790impl Stream for NestedLoopJoinStream {
791    type Item = Result<RecordBatch>;
792
    /// See the comments on [`NestedLoopJoinExec`] for high-level design ideas.
794    ///
795    /// # Implementation
796    ///
797    /// This function is the entry point of NLJ operator's state machine
    /// transitions. The rough state transition graph is as follows; for more
    /// details, see the comment in each state's match arm.
800    ///
801    /// ============================
802    /// State transition graph:
803    /// ============================
804    ///
805    /// (start) --> BufferingLeft
806    /// ----------------------------
807    /// BufferingLeft → FetchingRight
808    ///
809    /// FetchingRight → ProbeRight (if right batch available)
810    /// FetchingRight → EmitLeftUnmatched (if right exhausted)
811    ///
812    /// ProbeRight → ProbeRight (next left row or after yielding output)
813    /// ProbeRight → EmitRightUnmatched (for special join types like right join)
814    /// ProbeRight → FetchingRight (done with the current right batch)
815    ///
816    /// EmitRightUnmatched → FetchingRight
817    ///
818    /// EmitLeftUnmatched → EmitLeftUnmatched (only process 1 chunk for each
819    /// iteration)
820    /// EmitLeftUnmatched → Done (if finished)
821    /// ----------------------------
822    /// Done → (end)
823    fn poll_next(
824        mut self: std::pin::Pin<&mut Self>,
825        cx: &mut std::task::Context<'_>,
826    ) -> Poll<Option<Self::Item>> {
827        loop {
828            match self.state {
829                // # NLJState transitions
830                // --> FetchingRight
                // This state prepares the left-side batches; the next state,
                // `FetchingRight`, is responsible for preparing a single
                // probe-side batch before joining starts.
834                NLJState::BufferingLeft => {
835                    debug!("[NLJState] Entering: {:?}", self.state);
                    // Inside `collect_left_input` (the routine that buffers
                    // build-side batches), all related metrics except build
                    // time are updated.
839                    // stop on drop
840                    let build_metric = self.join_metrics.build_time.clone();
841                    let _build_timer = build_metric.timer();
842
843                    match self.handle_buffering_left(cx) {
844                        ControlFlow::Continue(()) => continue,
845                        ControlFlow::Break(poll) => return poll,
846                    }
847                }
848
849                // # NLJState transitions:
850                // 1. --> ProbeRight
851                //    Start processing the join for the newly fetched right
852                //    batch.
853                // 2. --> EmitLeftUnmatched: When the right side input is exhausted, (maybe) emit
854                //    unmatched left side rows.
855                //
856                // After fetching a new batch from the right side, it will
857                // process all rows from the buffered left data:
858                // ```text
859                // for batch in right_side:
860                //     for row in left_buffer:
861                //         join(batch, row)
862                // ```
863                // Note: the implementation does this step incrementally,
864                // instead of materializing all intermediate Cartesian products
865                // at once in memory.
866                //
867                // So after the right side input is exhausted, the join phase
868                // for the current buffered left data is finished. We can go to
869                // the next `EmitLeftUnmatched` phase to check if there is any
870                // special handling (e.g., in cases like left join).
871                NLJState::FetchingRight => {
872                    debug!("[NLJState] Entering: {:?}", self.state);
873                    // stop on drop
874                    let join_metric = self.join_metrics.join_time.clone();
875                    let _join_timer = join_metric.timer();
876
877                    match self.handle_fetching_right(cx) {
878                        ControlFlow::Continue(()) => continue,
879                        ControlFlow::Break(poll) => return poll,
880                    }
881                }
882
883                // NLJState transitions:
884                // 1. --> ProbeRight(1)
885                //    If we have already buffered enough output to yield, it
886                //    will first give back control to the parent state machine,
887                //    then resume at the same place.
888                // 2. --> ProbeRight(2)
889                //    After probing one right batch, and evaluating the
890                //    join filter on (left-row x right-batch), it will advance
891                //    to the next left row, then re-enter the current state and
892                //    continue joining.
893                // 3. --> FetchRight
894                //    After it has done with the current right batch (to join
895                //    with all rows in the left buffer), it will go to
896                //    FetchRight state to check what to do next.
897                NLJState::ProbeRight => {
898                    debug!("[NLJState] Entering: {:?}", self.state);
899
900                    // stop on drop
901                    let join_metric = self.join_metrics.join_time.clone();
902                    let _join_timer = join_metric.timer();
903
904                    match self.handle_probe_right() {
905                        ControlFlow::Continue(()) => continue,
906                        ControlFlow::Break(poll) => {
907                            return self.join_metrics.baseline.record_poll(poll)
908                        }
909                    }
910                }
911
                // In the `current_right_batch_matched` bitmap, a true bit means
                // the row has already been output by the join. In this state we
                // output the unmatched rows of the current right batch (with
                // null padding for the left relation).
                // Precondition: the join type has been checked, so emitting
                // right-side unmatched rows is required (e.g. it's a right join)
918                NLJState::EmitRightUnmatched => {
919                    debug!("[NLJState] Entering: {:?}", self.state);
920
921                    // stop on drop
922                    let join_metric = self.join_metrics.join_time.clone();
923                    let _join_timer = join_metric.timer();
924
925                    match self.handle_emit_right_unmatched() {
926                        ControlFlow::Continue(()) => continue,
927                        ControlFlow::Break(poll) => {
928                            return self.join_metrics.baseline.record_poll(poll)
929                        }
930                    }
931                }
932
933                // NLJState transitions:
934                // 1. --> EmitLeftUnmatched(1)
935                //    If we have already buffered enough output to yield, it
936                //    will first give back control to the parent state machine,
937                //    then resume at the same place.
938                // 2. --> EmitLeftUnmatched(2)
939                //    After processing some unmatched rows, it will re-enter
940                //    the same state, to check if there are any more final
941                //    results to output.
942                // 3. --> Done
943                //    It has processed all data, go to the final state and ready
944                //    to exit.
945                //
946                // TODO: For memory-limited case, go back to `BufferingLeft`
947                // state again.
948                NLJState::EmitLeftUnmatched => {
949                    debug!("[NLJState] Entering: {:?}", self.state);
950
951                    // stop on drop
952                    let join_metric = self.join_metrics.join_time.clone();
953                    let _join_timer = join_metric.timer();
954
955                    match self.handle_emit_left_unmatched() {
956                        ControlFlow::Continue(()) => continue,
957                        ControlFlow::Break(poll) => {
958                            return self.join_metrics.baseline.record_poll(poll)
959                        }
960                    }
961                }
962
963                // The final state and the exit point
964                NLJState::Done => {
965                    debug!("[NLJState] Entering: {:?}", self.state);
966
967                    // stop on drop
968                    let join_metric = self.join_metrics.join_time.clone();
969                    let _join_timer = join_metric.timer();
                    // Counted in the join timer because there might be some
                    // final result batches to output in this state
972
973                    let poll = self.handle_done();
974                    return self.join_metrics.baseline.record_poll(poll);
975                }
976            }
977        }
978    }
979}
980
981impl RecordBatchStream for NestedLoopJoinStream {
982    fn schema(&self) -> SchemaRef {
983        Arc::clone(&self.output_schema)
984    }
985}
986
987impl NestedLoopJoinStream {
988    #[allow(clippy::too_many_arguments)]
989    pub(crate) fn new(
990        schema: Arc<Schema>,
991        filter: Option<JoinFilter>,
992        join_type: JoinType,
993        right_data: SendableRecordBatchStream,
994        left_data: OnceFut<JoinLeftData>,
995        column_indices: Vec<ColumnIndex>,
996        join_metrics: BuildProbeJoinMetrics,
997        batch_size: usize,
998    ) -> Self {
999        Self {
1000            output_schema: Arc::clone(&schema),
1001            join_filter: filter,
1002            join_type,
1003            right_data,
1004            column_indices,
1005            left_data,
1006            join_metrics,
1007            buffered_left_data: None,
1008            output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
1009            batch_size,
1010            current_right_batch: None,
1011            current_right_batch_matched: None,
1012            state: NLJState::BufferingLeft,
1013            left_probe_idx: 0,
1014            left_emit_idx: 0,
1015            left_exhausted: false,
1016            left_buffered_in_one_pass: true,
1017            handled_empty_output: false,
1018            should_track_unmatched_right: need_produce_right_in_final(join_type),
1019        }
1020    }
1021
1022    // ==== State handler functions ====
1023
1024    /// Handle BufferingLeft state - prepare left side batches
1025    fn handle_buffering_left(
1026        &mut self,
1027        cx: &mut std::task::Context<'_>,
1028    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1029        match self.left_data.get_shared(cx) {
1030            Poll::Ready(Ok(left_data)) => {
1031                self.buffered_left_data = Some(left_data);
1032                // TODO: implement memory-limited case
1033                self.left_exhausted = true;
1034                self.state = NLJState::FetchingRight;
1035                // Continue to next state immediately
1036                ControlFlow::Continue(())
1037            }
1038            Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1039            Poll::Pending => ControlFlow::Break(Poll::Pending),
1040        }
1041    }
1042
1043    /// Handle FetchingRight state - fetch next right batch and prepare for processing
1044    fn handle_fetching_right(
1045        &mut self,
1046        cx: &mut std::task::Context<'_>,
1047    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1048        match self.right_data.poll_next_unpin(cx) {
1049            Poll::Ready(result) => match result {
1050                Some(Ok(right_batch)) => {
1051                    // Update metrics
1052                    let right_batch_size = right_batch.num_rows();
1053                    self.join_metrics.input_rows.add(right_batch_size);
1054                    self.join_metrics.input_batches.add(1);
1055
1056                    // Skip the empty batch
1057                    if right_batch_size == 0 {
1058                        return ControlFlow::Continue(());
1059                    }
1060
1061                    self.current_right_batch = Some(right_batch);
1062
1063                    // Prepare right bitmap
1064                    if self.should_track_unmatched_right {
1065                        let zeroed_buf = BooleanBuffer::new_unset(right_batch_size);
1066                        self.current_right_batch_matched =
1067                            Some(BooleanArray::new(zeroed_buf, None));
1068                    }
1069
1070                    self.left_probe_idx = 0;
1071                    self.state = NLJState::ProbeRight;
1072                    ControlFlow::Continue(())
1073                }
1074                Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1075                None => {
1076                    // Right stream exhausted
1077                    self.state = NLJState::EmitLeftUnmatched;
1078                    ControlFlow::Continue(())
1079                }
1080            },
1081            Poll::Pending => ControlFlow::Break(Poll::Pending),
1082        }
1083    }
1084
1085    /// Handle ProbeRight state - process current probe batch
1086    fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1087        // Return any completed batches first
1088        if let Some(poll) = self.maybe_flush_ready_batch() {
1089            return ControlFlow::Break(poll);
1090        }
1091
1092        // Process current probe state
1093        match self.process_probe_batch() {
1094            // State unchanged (ProbeRight)
            // Continue probing until we are done joining the
            // current right batch with all buffered left rows.
1097            Ok(true) => ControlFlow::Continue(()),
            // To the next state (EmitRightUnmatched or FetchingRight)
1099            // We have finished joining
1100            // (cur_right_batch x buffered_left_batches)
1101            Ok(false) => {
                // All buffered left rows probed; transition to the next state
1103                self.left_probe_idx = 0;
1104                if self.should_track_unmatched_right {
1105                    debug_assert!(
1106                        self.current_right_batch_matched.is_some(),
1107                        "If it's required to track matched rows in the right input, the right bitmap must be present"
1108                    );
1109                    self.state = NLJState::EmitRightUnmatched;
1110                } else {
1111                    self.current_right_batch = None;
1112                    self.state = NLJState::FetchingRight;
1113                }
1114                ControlFlow::Continue(())
1115            }
1116            Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1117        }
1118    }
1119
1120    /// Handle EmitRightUnmatched state - emit unmatched right rows
1121    fn handle_emit_right_unmatched(
1122        &mut self,
1123    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1124        // Return any completed batches first
1125        if let Some(poll) = self.maybe_flush_ready_batch() {
1126            return ControlFlow::Break(poll);
1127        }
1128
1129        debug_assert!(
1130            self.current_right_batch_matched.is_some()
1131                && self.current_right_batch.is_some(),
1132            "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
1133        );
1134
1135        // Construct the result batch for unmatched right rows using a utility function
1136        match self.process_right_unmatched() {
1137            Ok(Some(batch)) => {
1138                match self.output_buffer.push_batch(batch) {
1139                    Ok(()) => {
                        // The whole right batch was processed in one pass;
                        // `current_right_batch` was cleared inside `process_right_unmatched`
1142                        debug_assert!(self.current_right_batch.is_none());
1143                        self.state = NLJState::FetchingRight;
1144                        ControlFlow::Continue(())
1145                    }
1146                    Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1147                }
1148            }
1149            Ok(None) => {
                // The whole right batch was processed in one pass;
                // `current_right_batch` was cleared inside `process_right_unmatched`
1152                debug_assert!(self.current_right_batch.is_none());
1153                self.state = NLJState::FetchingRight;
1154                ControlFlow::Continue(())
1155            }
1156            Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1157        }
1158    }
1159
1160    /// Handle EmitLeftUnmatched state - emit unmatched left rows
1161    fn handle_emit_left_unmatched(
1162        &mut self,
1163    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1164        // Return any completed batches first
1165        if let Some(poll) = self.maybe_flush_ready_batch() {
1166            return ControlFlow::Break(poll);
1167        }
1168
1169        // Process current unmatched state
1170        match self.process_left_unmatched() {
1171            // State unchanged (EmitLeftUnmatched)
1172            // Continue processing until we have processed all unmatched rows
1173            Ok(true) => ControlFlow::Continue(()),
1174            // To Done state
1175            // We have finished processing all unmatched rows
1176            Ok(false) => match self.output_buffer.finish_buffered_batch() {
1177                Ok(()) => {
1178                    self.state = NLJState::Done;
1179                    ControlFlow::Continue(())
1180                }
1181                Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1182            },
1183            Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1184        }
1185    }
1186
1187    /// Handle Done state - final state processing
1188    fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
1189        // Return any remaining completed batches before final termination
1190        if let Some(poll) = self.maybe_flush_ready_batch() {
1191            return poll;
1192        }
1193
1194        // HACK for the doc test in https://github.com/apache/datafusion/blob/main/datafusion/core/src/dataframe/mod.rs#L1265
        // If this operator directly returned `Poll::Ready(None)` for an empty
        // result, the final result would be an empty batch with an empty
        // schema; however, the expected result should carry this operator's
        // schema.
1199        if !self.handled_empty_output {
1200            let zero_count = Count::new();
1201            if *self.join_metrics.baseline.output_rows() == zero_count {
1202                let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
1203                self.handled_empty_output = true;
1204                return Poll::Ready(Some(Ok(empty_batch)));
1205            }
1206        }
1207
1208        Poll::Ready(None)
1209    }
1210
1211    // ==== Core logic handling for each state ====
1212
    /// Returns a bool indicating whether probing should continue:
    /// true -> continue in the same ProbeRight state
    /// false -> done with (buffered_left x cur_right_batch); the caller moves
    /// to the next state (EmitRightUnmatched or FetchingRight)
1217    fn process_probe_batch(&mut self) -> Result<bool> {
1218        let left_data = Arc::clone(self.get_left_data()?);
1219        let right_batch = self
1220            .current_right_batch
1221            .as_ref()
1222            .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
1223            .clone();
1224
1225        // stop probing, the caller will go to the next state
1226        if self.left_probe_idx >= left_data.batch().num_rows() {
1227            return Ok(false);
1228        }
1229
1230        // ========
1231        // Join (l_row x right_batch)
1232        // and push the result into output_buffer
1233        // ========
1234
1235        let l_idx = self.left_probe_idx;
1236        let join_batch =
1237            self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
1238
1239        if let Some(batch) = join_batch {
1240            self.output_buffer.push_batch(batch)?;
1241        }
1242
1243        // ==== Prepare for the next iteration ====
1244
1245        // Advance left cursor
1246        self.left_probe_idx += 1;
1247
1248        // Return true to continue probing
1249        Ok(true)
1250    }
1251
1252    /// Process a single left row join with the current right batch.
1253    /// Returns a RecordBatch containing the join results (None if empty)
1254    fn process_single_left_row_join(
1255        &mut self,
1256        left_data: &JoinLeftData,
1257        right_batch: &RecordBatch,
1258        l_index: usize,
1259    ) -> Result<Option<RecordBatch>> {
1260        let right_row_count = right_batch.num_rows();
1261        if right_row_count == 0 {
1262            return Ok(None);
1263        }
1264
1265        let cur_right_bitmap = if let Some(filter) = &self.join_filter {
1266            apply_filter_to_row_join_batch(
1267                left_data.batch(),
1268                l_index,
1269                right_batch,
1270                filter,
1271            )?
1272        } else {
1273            BooleanArray::from(vec![true; right_row_count])
1274        };
1275
1276        self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
1277
        // For the following join types, we only have to update the left/right
        // bitmaps here; no result needs to be output
1280        if matches!(
1281            self.join_type,
1282            JoinType::LeftAnti
1283                | JoinType::LeftSemi
1284                | JoinType::LeftMark
1285                | JoinType::RightAnti
1286                | JoinType::RightMark
1287                | JoinType::RightSemi
1288        ) {
1289            return Ok(None);
1290        }
1291
1292        if cur_right_bitmap.true_count() == 0 {
1293            // If none of the pairs has passed the join predicate/filter
1294            Ok(None)
1295        } else {
1296            // Use the optimized approach similar to build_intermediate_batch_for_single_left_row
1297            let join_batch = build_row_join_batch(
1298                &self.output_schema,
1299                left_data.batch(),
1300                l_index,
1301                right_batch,
1302                Some(cur_right_bitmap),
1303                &self.column_indices,
1304                JoinSide::Left,
1305            )?;
1306            Ok(join_batch)
1307        }
1308    }
1309
    /// Returns a bool indicating whether unmatched-row processing should continue:
1311    /// true -> continue in the same EmitLeftUnmatched state
1312    /// false -> next state (Done)
1313    fn process_left_unmatched(&mut self) -> Result<bool> {
1314        let left_data = self.get_left_data()?;
1315        let left_batch = left_data.batch();
1316
1317        // ========
1318        // Check early return conditions
1319        // ========
1320
1321        // Early return if join type can't have unmatched rows
1322        let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
1323        // Early return if another thread is already processing unmatched rows
1324        let handled_by_other_partition =
1325            self.left_emit_idx == 0 && !left_data.report_probe_completed();
1326        // Stop processing unmatched rows, the caller will go to the next state
1327        let finished = self.left_emit_idx >= left_batch.num_rows();
1328
1329        if join_type_no_produce_left || handled_by_other_partition || finished {
1330            return Ok(false);
1331        }
1332
1333        // ========
1334        // Process unmatched rows and push the result into output_buffer
        // Each iteration processes up to `batch_size` rows
1336        // ========
1337        let start_idx = self.left_emit_idx;
1338        let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
1339
1340        if let Some(batch) =
1341            self.process_left_unmatched_range(left_data, start_idx, end_idx)?
1342        {
1343            self.output_buffer.push_batch(batch)?;
1344        }
1345
1346        // ==== Prepare for the next iteration ====
1347        self.left_emit_idx = end_idx;
1348
1349        // Return true to continue processing unmatched rows
1350        Ok(true)
1351    }
1352
1353    /// Process unmatched rows from the left data within the specified range.
1354    /// Returns a RecordBatch containing the unmatched rows (None if empty).
1355    ///
1356    /// # Arguments
1357    /// * `left_data` - The left side data containing the batch and bitmap
1358    /// * `start_idx` - Start index (inclusive) of the range to process
1359    /// * `end_idx` - End index (exclusive) of the range to process
1360    ///
1361    /// # Safety
1362    /// The caller is responsible for ensuring that `start_idx` and `end_idx` are
1363    /// within valid bounds of the left batch. This function does not perform
1364    /// bounds checking.
1365    fn process_left_unmatched_range(
1366        &self,
1367        left_data: &JoinLeftData,
1368        start_idx: usize,
1369        end_idx: usize,
1370    ) -> Result<Option<RecordBatch>> {
1371        if start_idx == end_idx {
1372            return Ok(None);
1373        }
1374
        // Slice both the left batch and the bitmap to the range [start_idx, end_idx).
        // The range is expressed in bit indices (not bytes).
1377        let left_batch = left_data.batch();
1378        let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx);
1379
        // TODO(perf): this bit-by-bit copy could likely be made more efficient
1381        let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx);
1382        bitmap_sliced.append_n(end_idx - start_idx, false);
1383        let bitmap = left_data.bitmap().lock();
1384        for i in start_idx..end_idx {
            assert!(
                i - start_idx < bitmap_sliced.capacity(),
                "sliced bitmap index out of bounds: start_idx={start_idx}, end_idx={end_idx}"
            );
1389            bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
1390        }
1391        let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
1392
1393        build_unmatched_batch(
1394            Arc::clone(&self.output_schema),
1395            &left_batch_sliced,
1396            bitmap_sliced,
1397            self.right_data.schema(),
1398            &self.column_indices,
1399            self.join_type,
1400            JoinSide::Left,
1401        )
1402    }
1403
1404    /// Process unmatched rows from the current right batch and reset the bitmap.
1405    /// Returns a RecordBatch containing the unmatched right rows (None if empty).
1406    fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
1407        // ==== Take current right batch and its bitmap ====
1408        let right_batch_bitmap: BooleanArray =
1409            std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
1410                internal_datafusion_err!("right bitmap should be available")
1411            })?;
1412
1413        let right_batch = self.current_right_batch.take();
1414        let cur_right_batch = unwrap_or_internal_err!(right_batch);
1415
1416        let left_data = self.get_left_data()?;
1417        let left_schema = left_data.batch().schema();
1418
1419        let res = build_unmatched_batch(
1420            Arc::clone(&self.output_schema),
1421            &cur_right_batch,
1422            right_batch_bitmap,
1423            left_schema,
1424            &self.column_indices,
1425            self.join_type,
1426            JoinSide::Right,
1427        );
1428
1429        // ==== Clean-up ====
1430        self.current_right_batch_matched = None;
1431
1432        res
1433    }
1434
1435    // ==== Utilities ====
1436
    /// Get the buffered build-side (left) data; returns an error if it is not available
1438    fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
1439        self.buffered_left_data
1440            .as_ref()
1441            .ok_or_else(|| internal_datafusion_err!("LeftData should be available"))
1442    }
1443
    /// Flush the `output_buffer` if there is a batch ready to output.
    /// Returns `None` if no result batch is ready.
1446    fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
1447        if self.output_buffer.has_completed_batch() {
1448            if let Some(batch) = self.output_buffer.next_completed_batch() {
1449                // HACK: this is not part of `BaselineMetrics` yet, so update it
1450                // manually
1451                self.join_metrics.output_batches.add(1);
1452
1453                return Some(Poll::Ready(Some(Ok(batch))));
1454            }
1455        }
1456
1457        None
1458    }
1459
    /// After joining (row `l_index` of the left buffer) x (current_right_batch),
    /// the result is a bitmap (the same length as current_right_batch) marking
    /// which right rows matched. Use this bitmap to update the global bitmaps
    /// required by join types such as full joins.
    ///
    /// Example:
    /// After joining `l_index = 1` (the row at index 1 in the left buffer) with
    /// a current right batch of 3 rows, this function is called with
    /// arguments: l_index = 1, r_matched = [false, false, true]
    /// - If the join type is Full, bit 1 of the left bitmap is set to true, and
    ///   the right bitmap is bitwise-ORed with the input r_matched bitmap.
    /// - Join types that don't need to output unmatched rows require less work:
    ///   for inner joins this function is a no-op; for left joins only the left
    ///   bitmap is updated.
1475    fn update_matched_bitmap(
1476        &mut self,
1477        l_index: usize,
1478        r_matched_bitmap: &BooleanArray,
1479    ) -> Result<()> {
1480        let left_data = self.get_left_data()?;
1481
1482        // number of successfully joined pairs from (l_index x cur_right_batch)
1483        let joined_len = r_matched_bitmap.true_count();
1484
1485        // 1. Maybe update the left bitmap
1486        if need_produce_result_in_final(self.join_type) && (joined_len > 0) {
1487            let mut bitmap = left_data.bitmap().lock();
1488            bitmap.set_bit(l_index, true);
1489        }
1490
        // 2. Maybe update the right bitmap
1492        if self.should_track_unmatched_right {
1493            debug_assert!(self.current_right_batch_matched.is_some());
            // Take the bitmap, bitwise-OR it with the new matches, then put it back
1495            let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
1496                .ok_or_else(|| {
1497                    internal_datafusion_err!("right batch's bitmap should be present")
1498                })?;
1499            let (buf, nulls) = right_bitmap.into_parts();
1500            debug_assert!(nulls.is_none());
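            // Illustrative example (hypothetical values):
            //   existing right bitmap: [true,  false, false]
            //   r_matched_bitmap:      [false, false, true ]
            //   bitwise OR result:     [true,  false, true ]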
1501            let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
1502
1503            self.current_right_batch_matched =
1504                Some(BooleanArray::new(updated_right_bitmap, None));
1505        }
1506
1507        Ok(())
1508    }
1509}
1510
1511// ==== Utilities ====
1512
/// Apply the join filter between
/// (the `l_index`-th row in the left buffer) x (the right batch).
/// Returns a bitmap with successfully joined indices set to true.
1516fn apply_filter_to_row_join_batch(
1517    left_batch: &RecordBatch,
1518    l_index: usize,
1519    right_batch: &RecordBatch,
1520    filter: &JoinFilter,
1521) -> Result<BooleanArray> {
1522    debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
1523
1524    let intermediate_batch = if filter.schema.fields().is_empty() {
        // If the filter is constant (e.g. the literal `true`), an empty batch
        // with the correct row count can be used in the later filter step.
1527        create_record_batch_with_empty_schema(
1528            Arc::new((*filter.schema).clone()),
1529            right_batch.num_rows(),
1530        )?
1531    } else {
1532        build_row_join_batch(
1533            &filter.schema,
1534            left_batch,
1535            l_index,
1536            right_batch,
1537            None,
1538            &filter.column_indices,
1539            JoinSide::Left,
1540        )?
        .ok_or_else(|| {
            internal_datafusion_err!(
                "This function assumes the input batches are not empty, so the intermediate batch cannot be empty either"
            )
        })?
1542    };
1543
1544    let filter_result = filter
1545        .expression()
1546        .evaluate(&intermediate_batch)?
1547        .into_array(intermediate_batch.num_rows())?;
1548    let filter_arr = as_boolean_array(&filter_result)?;
1549
1550    // [Caution] This step has previously introduced bugs
1551    // The filter result is NOT a bitmap; it contains true/false/null values.
1552    // For example, 1 < NULL is evaluated to NULL. Therefore, we must combine (AND)
1553    // the boolean array with its null bitmap to construct a unified bitmap.
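    // Illustrative example (hypothetical values):
    //   filter values:   [true,  true,  false]
    //   validity bitmap: [valid, null,  valid]
    //   combined bitmap: [true,  false, false]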
1554    let (is_filtered, nulls) = filter_arr.clone().into_parts();
1555    let bitmap_combined = match nulls {
1556        Some(nulls) => {
1557            let combined = nulls.inner() & &is_filtered;
1558            BooleanArray::new(combined, None)
1559        }
1560        None => BooleanArray::new(is_filtered, None),
1561    };
1562
1563    Ok(bitmap_combined)
1564}
1565
1566/// This function performs the following steps:
/// 1. Apply the filter to the probe-side batch
/// 2. Broadcast the build-side row (build_side_batch\[build_side_index\]) to the
///    length of the filtered probe-side batch
/// 3. Assemble the columns according to `col_indices`, and return the result
///    (None if the result is empty)
1572///
1573/// Example:
1574/// build_side_batch:
1575/// a
1576/// ----
1577/// 1
1578/// 2
1579/// 3
1580///
/// # The element at index 0 of build_side_batch (that is `1`) will be used
1582/// build_side_index: 0
1583///
1584/// probe_side_batch:
1585/// b
1586/// ----
1587/// 10
1588/// 20
1589/// 30
1590/// 40
1591///
/// # After applying the filter, only the elements at indices 1 and 3 of
/// # probe_side_batch are kept
1594/// probe_side_filter:
1595/// false
1596/// true
1597/// false
1598/// true
1599///
1600///
1601/// # Projections to the build/probe side batch, to construct the output batch
1602/// col_indices:
1603/// [(left, 0), (right, 0)]
1604///
1605/// build_side: left
1606///
1607/// ====
1608/// Result batch:
1609/// a b
1610/// ----
1611/// 1 20
1612/// 1 40
1613fn build_row_join_batch(
1614    output_schema: &Schema,
1615    build_side_batch: &RecordBatch,
1616    build_side_index: usize,
1617    probe_side_batch: &RecordBatch,
1618    probe_side_filter: Option<BooleanArray>,
1619    // See [`NLJStream`] struct's `column_indices` field for more detail
1620    col_indices: &[ColumnIndex],
1621    // If the build side is left or right, used to interpret the side information
1622    // in `col_indices`
1623    build_side: JoinSide,
1624) -> Result<Option<RecordBatch>> {
1625    debug_assert!(build_side != JoinSide::None);
1626
    // TODO(perf): since the output might be a projection of the right batch, it
    // would be more efficient to do this filtering inside the column_index loop
1629    let filtered_probe_batch = if let Some(filter) = probe_side_filter {
1630        &filter_record_batch(probe_side_batch, &filter)?
1631    } else {
1632        probe_side_batch
1633    };
1634
1635    if filtered_probe_batch.num_rows() == 0 {
1636        return Ok(None);
1637    }
1638
1639    // Edge case: downstream operator does not require any columns from this NLJ,
1640    // so allow an empty projection.
1641    // Example:
1642    //  SELECT DISTINCT 32 AS col2
1643    //  FROM tab0 AS cor0
1644    //  LEFT OUTER JOIN tab2 AS cor1
1645    //  ON ( NULL ) IS NULL;
1646    if output_schema.fields.is_empty() {
1647        return Ok(Some(create_record_batch_with_empty_schema(
1648            Arc::new(output_schema.clone()),
1649            filtered_probe_batch.num_rows(),
1650        )?));
1651    }
1652
1653    let mut columns: Vec<Arc<dyn Array>> =
1654        Vec::with_capacity(output_schema.fields().len());
1655
1656    for column_index in col_indices {
1657        let array = if column_index.side == build_side {
1658            // Broadcast the single build-side row to match the filtered
1659            // probe-side batch length
1660            let original_left_array = build_side_batch.column(column_index.index);
1661            let scalar_value = ScalarValue::try_from_array(
1662                original_left_array.as_ref(),
1663                build_side_index,
1664            )?;
1665            scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
1666        } else {
            // Reuse the already-filtered probe-side column directly
1668            Arc::clone(filtered_probe_batch.column(column_index.index))
1669        };
1670
1671        columns.push(array);
1672    }
1673
1674    Ok(Some(RecordBatch::try_new(
1675        Arc::new(output_schema.clone()),
1676        columns,
1677    )?))
1678}
1679
/// Special case for `PlaceHolderRowExec`.
/// Minimal example: SELECT 1 WHERE EXISTS (SELECT 1);
///
/// # Return
/// If `Some`, that is the result batch.
/// If `None`, this special case does not apply; continue execution.
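///
/// For example (illustrative): with an empty output schema,
/// `join_type = LeftSemi` and `batch_bitmap = [true, false, false]`, this
/// returns a `Some` batch with 1 row and zero columns.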
1686fn build_unmatched_batch_empty_schema(
1687    output_schema: SchemaRef,
1688    batch_bitmap: &BooleanArray,
    // For left/right/full joins, nulls need to be filled in for the other side
1690    join_type: JoinType,
1691) -> Result<Option<RecordBatch>> {
1692    let result_size = match join_type {
1693        JoinType::Left
1694        | JoinType::Right
1695        | JoinType::Full
1696        | JoinType::LeftAnti
1697        | JoinType::RightAnti => batch_bitmap.false_count(),
1698        JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
1699        JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
1700        _ => unreachable!(),
1701    };
1702
1703    if output_schema.fields().is_empty() {
1704        Ok(Some(create_record_batch_with_empty_schema(
1705            Arc::clone(&output_schema),
1706            result_size,
1707        )?))
1708    } else {
1709        Ok(None)
1710    }
1711}
1712
1713/// Creates an empty RecordBatch with a specific row count.
1714/// This is useful for cases where we need a batch with the correct schema and row count
1715/// but no actual data columns (e.g., for constant filters).
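///
/// For example (illustrative): given an empty `Schema` and `row_count = 3`,
/// the returned batch reports `num_rows() == 3` and has zero columns.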
1716fn create_record_batch_with_empty_schema(
1717    schema: SchemaRef,
1718    row_count: usize,
1719) -> Result<RecordBatch> {
1720    let options = RecordBatchOptions::new()
1721        .with_match_field_names(true)
1722        .with_row_count(Some(row_count));
1723
1724    RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| {
1725        internal_datafusion_err!("Failed to create empty record batch: {}", e)
1726    })
1727}
1728
1729/// # Example:
1730/// batch:
1731/// a
1732/// ----
1733/// 1
1734/// 2
1735/// 3
1736///
1737/// batch_bitmap:
1738/// ----
1739/// false
1740/// true
1741/// false
1742///
1743/// another_side_schema:
1744/// [(b, bool), (c, int32)]
1745///
/// join_type: JoinType::Right
1747///
/// col_indices: ... (see the comment on `NLJStream::column_indices`)
1749///
1750/// batch_side: right
1751///
1752/// # Walkthrough:
1753///
1754/// This executor is performing a right join, and the currently processed right
1755/// batch is as above. After joining it with all buffered left rows, the joined
1756/// entries are marked by the `batch_bitmap`.
1757/// This method will keep the unmatched indices on the batch side (right), and pad
1758/// the left side with nulls. The result would be:
1759///
1760/// b          c           a
1761/// ------------------------
1762/// Null(bool) Null(Int32) 1
1763/// Null(bool) Null(Int32) 3
1764fn build_unmatched_batch(
1765    output_schema: SchemaRef,
1766    batch: &RecordBatch,
1767    batch_bitmap: BooleanArray,
    // For left/right/full joins, nulls need to be filled in for the other side
1769    another_side_schema: SchemaRef,
1770    col_indices: &[ColumnIndex],
1771    join_type: JoinType,
1772    batch_side: JoinSide,
1773) -> Result<Option<RecordBatch>> {
    // This function should not be called for inner joins
1775    debug_assert_ne!(join_type, JoinType::Inner);
1776    debug_assert_ne!(batch_side, JoinSide::None);
1777
1778    // Handle special case (see function comment)
1779    if let Some(batch) = build_unmatched_batch_empty_schema(
1780        Arc::clone(&output_schema),
1781        &batch_bitmap,
1782        join_type,
1783    )? {
1784        return Ok(Some(batch));
1785    }
1786
1787    match join_type {
1788        JoinType::Full | JoinType::Right | JoinType::Left => {
1789            if join_type == JoinType::Right {
1790                debug_assert_eq!(batch_side, JoinSide::Right);
1791            }
1792            if join_type == JoinType::Left {
1793                debug_assert_eq!(batch_side, JoinSide::Left);
1794            }
1795
            // 1. Filter the batch with the *flipped* bitmap
            // 2. Fill the other side with nulls
1798            let flipped_bitmap = not(&batch_bitmap)?;
1799
            // Create a RecordBatch with the other side's schema, containing a
            // single all-null row
1801            let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
1802                .fields()
1803                .iter()
1804                .map(|field| new_null_array(field.data_type(), 1))
1805                .collect();
1806
            // Hack: even if the other side's schema is not nullable, the outer
            // join result may contain nulls; this temporary schema is made
            // nullable only to construct such result batches.
1810            let nullable_left_schema = Arc::new(Schema::new(
1811                another_side_schema
1812                    .fields()
1813                    .iter()
                    .map(|field| (**field).clone().with_nullable(true))
1817                    .collect::<Vec<_>>(),
1818            ));
1819            let left_null_batch = if nullable_left_schema.fields.is_empty() {
                // The other input can be an empty relation; in that case it is not
                // used to construct the result batch (i.e. it's not in `col_indices`)
1822                create_record_batch_with_empty_schema(nullable_left_schema, 0)?
1823            } else {
1824                RecordBatch::try_new(nullable_left_schema, left_null_columns)?
1825            };
1826
1827            debug_assert_ne!(batch_side, JoinSide::None);
1828            let opposite_side = batch_side.negate();
1829
            build_row_join_batch(
                &output_schema,
                &left_null_batch,
                0,
                batch,
                Some(flipped_bitmap),
                col_indices,
                opposite_side,
            )
        },
1833        JoinType::RightSemi | JoinType::RightAnti | JoinType::LeftSemi | JoinType::LeftAnti => {
1834            if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
1835                debug_assert_eq!(batch_side, JoinSide::Right);
1836            }
1837            if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1838                debug_assert_eq!(batch_side, JoinSide::Left);
1839            }
1840
1841            let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) {
1842                batch_bitmap.clone()
1843            } else {
1844                not(&batch_bitmap)?
1845            };
1846
1847            if bitmap.true_count() == 0 {
1848                return Ok(None);
1849            }
1850
1851            let mut columns: Vec<Arc<dyn Array>> =
1852                Vec::with_capacity(output_schema.fields().len());
1853
1854            for column_index in col_indices {
1855                debug_assert!(column_index.side == batch_side);
1856
1857                let col = batch.column(column_index.index);
1858                let filtered_col = filter(col, &bitmap)?;
1859
1860                columns.push(filtered_col);
1861            }
1862
1863            Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1864        },
1865        JoinType::RightMark | JoinType::LeftMark => {
1866            if join_type == JoinType::RightMark {
1867                debug_assert_eq!(batch_side, JoinSide::Right);
1868            }
1869            if join_type == JoinType::LeftMark {
1870                debug_assert_eq!(batch_side, JoinSide::Left);
1871            }
1872
1873            let mut columns: Vec<Arc<dyn Array>> =
1874                Vec::with_capacity(output_schema.fields().len());
1875
            // Wrap in Option so the bitmap can be moved out exactly once inside the loop
1877            let mut right_batch_bitmap_opt = Some(batch_bitmap);
1878
1879            for column_index in col_indices {
1880                if column_index.side == batch_side {
1881                    let col = batch.column(column_index.index);
1882
1883                    columns.push(Arc::clone(col));
1884                } else if column_index.side == JoinSide::None {
1885                    let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
1886                    match right_batch_bitmap {
                        Some(right_batch_bitmap) => columns.push(Arc::new(right_batch_bitmap)),
1888                        None => unreachable!("Should only be one mark column"),
1889                    }
1890                } else {
                    return internal_err!("Unexpected join side for mark join");
1892                }
1893            }
1894
1895            Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1896        }
        _ => internal_err!("build_unmatched_batch should not be called for Inner joins"),
1898    }
1899}
1900
1901#[cfg(test)]
1902pub(crate) mod tests {
1903    use super::*;
1904    use crate::test::{assert_join_metrics, TestMemoryExec};
1905    use crate::{
1906        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
1907    };
1908
1909    use arrow::compute::SortOptions;
1910    use arrow::datatypes::{DataType, Field};
1911    use datafusion_common::test_util::batches_to_sort_string;
1912    use datafusion_common::{assert_contains, ScalarValue};
1913    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1914    use datafusion_expr::Operator;
1915    use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
1916    use datafusion_physical_expr::{Partitioning, PhysicalExpr};
1917    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
1918
1919    use insta::allow_duplicates;
1920    use insta::assert_snapshot;
1921    use rstest::rstest;
1922
1923    fn build_table(
1924        a: (&str, &Vec<i32>),
1925        b: (&str, &Vec<i32>),
1926        c: (&str, &Vec<i32>),
1927        batch_size: Option<usize>,
1928        sorted_column_names: Vec<&str>,
1929    ) -> Arc<dyn ExecutionPlan> {
1930        let batch = build_table_i32(a, b, c);
1931        let schema = batch.schema();
1932
1933        let batches = if let Some(batch_size) = batch_size {
1934            let num_batches = batch.num_rows().div_ceil(batch_size);
1935            (0..num_batches)
1936                .map(|i| {
1937                    let start = i * batch_size;
1938                    let remaining_rows = batch.num_rows() - start;
1939                    batch.slice(start, batch_size.min(remaining_rows))
1940                })
1941                .collect::<Vec<_>>()
1942        } else {
1943            vec![batch]
1944        };
1945
1946        let mut sort_info = vec![];
1947        for name in sorted_column_names {
1948            let index = schema.index_of(name).unwrap();
1949            let sort_expr = PhysicalSortExpr::new(
1950                Arc::new(Column::new(name, index)),
1951                SortOptions::new(false, false),
1952            );
1953            sort_info.push(sort_expr);
1954        }
1955        let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
1956        if let Some(ordering) = LexOrdering::new(sort_info) {
1957            source = source.try_with_sort_information(vec![ordering]).unwrap();
1958        }
1959
1960        Arc::new(TestMemoryExec::update_cache(Arc::new(source)))
1961    }
1962
1963    fn build_left_table() -> Arc<dyn ExecutionPlan> {
1964        build_table(
1965            ("a1", &vec![5, 9, 11]),
1966            ("b1", &vec![5, 8, 8]),
1967            ("c1", &vec![50, 90, 110]),
1968            None,
1969            Vec::new(),
1970        )
1971    }
1972
1973    fn build_right_table() -> Arc<dyn ExecutionPlan> {
1974        build_table(
1975            ("a2", &vec![12, 2, 10]),
1976            ("b2", &vec![10, 2, 10]),
1977            ("c2", &vec![40, 80, 100]),
1978            None,
1979            Vec::new(),
1980        )
1981    }
1982
1983    fn prepare_join_filter() -> JoinFilter {
1984        let column_indices = vec![
1985            ColumnIndex {
1986                index: 1,
1987                side: JoinSide::Left,
1988            },
1989            ColumnIndex {
1990                index: 1,
1991                side: JoinSide::Right,
1992            },
1993        ];
1994        let intermediate_schema = Schema::new(vec![
1995            Field::new("x", DataType::Int32, true),
1996            Field::new("x", DataType::Int32, true),
1997        ]);
1998        // left.b1!=8
1999        let left_filter = Arc::new(BinaryExpr::new(
2000            Arc::new(Column::new("x", 0)),
2001            Operator::NotEq,
2002            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
2003        )) as Arc<dyn PhysicalExpr>;
2004        // right.b2!=10
2005        let right_filter = Arc::new(BinaryExpr::new(
2006            Arc::new(Column::new("x", 1)),
2007            Operator::NotEq,
2008            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2009        )) as Arc<dyn PhysicalExpr>;
2010        // filter = left.b1!=8 and right.b2!=10
2011        // after filter:
2012        // left table:
2013        // ("a1", &vec![5]),
2014        // ("b1", &vec![5]),
2015        // ("c1", &vec![50]),
2016        // right table:
2017        // ("a2", &vec![12, 2]),
2018        // ("b2", &vec![10, 2]),
2019        // ("c2", &vec![40, 80]),
2020        let filter_expression =
2021            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
2022                as Arc<dyn PhysicalExpr>;
2023
2024        JoinFilter::new(
2025            filter_expression,
2026            column_indices,
2027            Arc::new(intermediate_schema),
2028        )
2029    }
2030
2031    pub(crate) async fn multi_partitioned_join_collect(
2032        left: Arc<dyn ExecutionPlan>,
2033        right: Arc<dyn ExecutionPlan>,
2034        join_type: &JoinType,
2035        join_filter: Option<JoinFilter>,
2036        context: Arc<TaskContext>,
2037    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
2038        let partition_count = 4;
2039
2040        // Redistributing right input
2041        let right = Arc::new(RepartitionExec::try_new(
2042            right,
2043            Partitioning::RoundRobinBatch(partition_count),
2044        )?) as Arc<dyn ExecutionPlan>;
2045
        // Use the required distribution for the nested loop join to test partitioned data
2047        let nested_loop_join =
2048            NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
2049        let columns = columns(&nested_loop_join.schema());
2050        let mut batches = vec![];
2051        for i in 0..partition_count {
2052            let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
2053            let more_batches = common::collect(stream).await?;
2054            batches.extend(
2055                more_batches
2056                    .into_iter()
2057                    .inspect(|b| {
2058                        assert!(b.num_rows() <= context.session_config().batch_size())
2059                    })
2060                    .filter(|b| b.num_rows() > 0)
2061                    .collect::<Vec<_>>(),
2062            );
2063        }
2064
2065        let metrics = nested_loop_join.metrics().unwrap();
2066
2067        Ok((columns, batches, metrics))
2068    }
2069
2070    fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
2071        let base = TaskContext::default();
        // Limit the max size of intermediate batches used in the NLJ to `batch_size`
2073        let cfg = base.session_config().clone().with_batch_size(batch_size);
2074        Arc::new(base.with_session_config(cfg))
2075    }
2076
2077    #[rstest]
2078    #[tokio::test]
2079    async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2080        let task_ctx = new_task_ctx(batch_size);
2082        let left = build_left_table();
2083        let right = build_right_table();
2084        let filter = prepare_join_filter();
2085        let (columns, batches, metrics) = multi_partitioned_join_collect(
2086            left,
2087            right,
2088            &JoinType::Inner,
2089            Some(filter),
2090            task_ctx,
2091        )
2092        .await?;
2093
2094        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2095        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2096            +----+----+----+----+----+----+
2097            | a1 | b1 | c1 | a2 | b2 | c2 |
2098            +----+----+----+----+----+----+
2099            | 5  | 5  | 50 | 2  | 2  | 80 |
2100            +----+----+----+----+----+----+
2101            "#));
2102
2103        assert_join_metrics!(metrics, 1);
2104
2105        Ok(())
2106    }
2107
2108    #[rstest]
2109    #[tokio::test]
2110    async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2111        let task_ctx = new_task_ctx(batch_size);
2112        let left = build_left_table();
2113        let right = build_right_table();
2114
2115        let filter = prepare_join_filter();
2116        let (columns, batches, metrics) = multi_partitioned_join_collect(
2117            left,
2118            right,
2119            &JoinType::Left,
2120            Some(filter),
2121            task_ctx,
2122        )
2123        .await?;
2124        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2125        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2126            +----+----+-----+----+----+----+
2127            | a1 | b1 | c1  | a2 | b2 | c2 |
2128            +----+----+-----+----+----+----+
2129            | 11 | 8  | 110 |    |    |    |
2130            | 5  | 5  | 50  | 2  | 2  | 80 |
2131            | 9  | 8  | 90  |    |    |    |
2132            +----+----+-----+----+----+----+
2133            "#));
2134
2135        assert_join_metrics!(metrics, 3);
2136
2137        Ok(())
2138    }
2139
2140    #[rstest]
2141    #[tokio::test]
2142    async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2143        let task_ctx = new_task_ctx(batch_size);
2144        let left = build_left_table();
2145        let right = build_right_table();
2146
2147        let filter = prepare_join_filter();
2148        let (columns, batches, metrics) = multi_partitioned_join_collect(
2149            left,
2150            right,
2151            &JoinType::Right,
2152            Some(filter),
2153            task_ctx,
2154        )
2155        .await?;
2156        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2157        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2158            +----+----+----+----+----+-----+
2159            | a1 | b1 | c1 | a2 | b2 | c2  |
2160            +----+----+----+----+----+-----+
2161            |    |    |    | 10 | 10 | 100 |
2162            |    |    |    | 12 | 10 | 40  |
2163            | 5  | 5  | 50 | 2  | 2  | 80  |
2164            +----+----+----+----+----+-----+
2165            "#));
2166
2167        assert_join_metrics!(metrics, 3);
2168
2169        Ok(())
2170    }
2171
2172    #[rstest]
2173    #[tokio::test]
2174    async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2175        let task_ctx = new_task_ctx(batch_size);
2176        let left = build_left_table();
2177        let right = build_right_table();
2178
2179        let filter = prepare_join_filter();
2180        let (columns, batches, metrics) = multi_partitioned_join_collect(
2181            left,
2182            right,
2183            &JoinType::Full,
2184            Some(filter),
2185            task_ctx,
2186        )
2187        .await?;
2188        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2189        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2190            +----+----+-----+----+----+-----+
2191            | a1 | b1 | c1  | a2 | b2 | c2  |
2192            +----+----+-----+----+----+-----+
2193            |    |    |     | 10 | 10 | 100 |
2194            |    |    |     | 12 | 10 | 40  |
2195            | 11 | 8  | 110 |    |    |     |
2196            | 5  | 5  | 50  | 2  | 2  | 80  |
2197            | 9  | 8  | 90  |    |    |     |
2198            +----+----+-----+----+----+-----+
2199            "#));
2200
2201        assert_join_metrics!(metrics, 5);
2202
2203        Ok(())
2204    }
2205
2206    #[rstest]
2207    #[tokio::test]
2208    async fn join_left_semi_with_filter(
2209        #[values(1, 2, 16)] batch_size: usize,
2210    ) -> Result<()> {
2211        let task_ctx = new_task_ctx(batch_size);
2212        let left = build_left_table();
2213        let right = build_right_table();
2214
2215        let filter = prepare_join_filter();
2216        let (columns, batches, metrics) = multi_partitioned_join_collect(
2217            left,
2218            right,
2219            &JoinType::LeftSemi,
2220            Some(filter),
2221            task_ctx,
2222        )
2223        .await?;
2224        assert_eq!(columns, vec!["a1", "b1", "c1"]);
2225        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2226            +----+----+----+
2227            | a1 | b1 | c1 |
2228            +----+----+----+
2229            | 5  | 5  | 50 |
2230            +----+----+----+
2231            "#));
2232
2233        assert_join_metrics!(metrics, 1);
2234
2235        Ok(())
2236    }
2237
2238    #[rstest]
2239    #[tokio::test]
2240    async fn join_left_anti_with_filter(
2241        #[values(1, 2, 16)] batch_size: usize,
2242    ) -> Result<()> {
2243        let task_ctx = new_task_ctx(batch_size);
2244        let left = build_left_table();
2245        let right = build_right_table();
2246
2247        let filter = prepare_join_filter();
2248        let (columns, batches, metrics) = multi_partitioned_join_collect(
2249            left,
2250            right,
2251            &JoinType::LeftAnti,
2252            Some(filter),
2253            task_ctx,
2254        )
2255        .await?;
2256        assert_eq!(columns, vec!["a1", "b1", "c1"]);
2257        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2258            +----+----+-----+
2259            | a1 | b1 | c1  |
2260            +----+----+-----+
2261            | 11 | 8  | 110 |
2262            | 9  | 8  | 90  |
2263            +----+----+-----+
2264            "#));
2265
2266        assert_join_metrics!(metrics, 2);
2267
2268        Ok(())
2269    }
2270
2271    #[rstest]
2272    #[tokio::test]
2273    async fn join_right_semi_with_filter(
2274        #[values(1, 2, 16)] batch_size: usize,
2275    ) -> Result<()> {
2276        let task_ctx = new_task_ctx(batch_size);
2277        let left = build_left_table();
2278        let right = build_right_table();
2279
2280        let filter = prepare_join_filter();
2281        let (columns, batches, metrics) = multi_partitioned_join_collect(
2282            left,
2283            right,
2284            &JoinType::RightSemi,
2285            Some(filter),
2286            task_ctx,
2287        )
2288        .await?;
2289        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2290        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2291            +----+----+----+
2292            | a2 | b2 | c2 |
2293            +----+----+----+
2294            | 2  | 2  | 80 |
2295            +----+----+----+
2296            "#));
2297
2298        assert_join_metrics!(metrics, 1);
2299
2300        Ok(())
2301    }
2302
2303    #[rstest]
2304    #[tokio::test]
2305    async fn join_right_anti_with_filter(
2306        #[values(1, 2, 16)] batch_size: usize,
2307    ) -> Result<()> {
2308        let task_ctx = new_task_ctx(batch_size);
2309        let left = build_left_table();
2310        let right = build_right_table();
2311
2312        let filter = prepare_join_filter();
2313        let (columns, batches, metrics) = multi_partitioned_join_collect(
2314            left,
2315            right,
2316            &JoinType::RightAnti,
2317            Some(filter),
2318            task_ctx,
2319        )
2320        .await?;
2321        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2322        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2323            +----+----+-----+
2324            | a2 | b2 | c2  |
2325            +----+----+-----+
2326            | 10 | 10 | 100 |
2327            | 12 | 10 | 40  |
2328            +----+----+-----+
2329            "#));
2330
2331        assert_join_metrics!(metrics, 2);
2332
2333        Ok(())
2334    }
2335
2336    #[rstest]
2337    #[tokio::test]
2338    async fn join_left_mark_with_filter(
2339        #[values(1, 2, 16)] batch_size: usize,
2340    ) -> Result<()> {
2341        let task_ctx = new_task_ctx(batch_size);
2342        let left = build_left_table();
2343        let right = build_right_table();
2344
2345        let filter = prepare_join_filter();
2346        let (columns, batches, metrics) = multi_partitioned_join_collect(
2347            left,
2348            right,
2349            &JoinType::LeftMark,
2350            Some(filter),
2351            task_ctx,
2352        )
2353        .await?;
2354        assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
2355        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2356            +----+----+-----+-------+
2357            | a1 | b1 | c1  | mark  |
2358            +----+----+-----+-------+
2359            | 11 | 8  | 110 | false |
2360            | 5  | 5  | 50  | true  |
2361            | 9  | 8  | 90  | false |
2362            +----+----+-----+-------+
2363            "#));
2364
2365        assert_join_metrics!(metrics, 3);
2366
2367        Ok(())
2368    }
2369
2370    #[rstest]
2371    #[tokio::test]
2372    async fn join_right_mark_with_filter(
2373        #[values(1, 2, 16)] batch_size: usize,
2374    ) -> Result<()> {
2375        let task_ctx = new_task_ctx(batch_size);
2376        let left = build_left_table();
2377        let right = build_right_table();
2378
2379        let filter = prepare_join_filter();
2380        let (columns, batches, metrics) = multi_partitioned_join_collect(
2381            left,
2382            right,
2383            &JoinType::RightMark,
2384            Some(filter),
2385            task_ctx,
2386        )
2387        .await?;
2388        assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
2389
2390        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2391            +----+----+-----+-------+
2392            | a2 | b2 | c2  | mark  |
2393            +----+----+-----+-------+
2394            | 10 | 10 | 100 | false |
2395            | 12 | 10 | 40  | false |
2396            | 2  | 2  | 80  | true  |
2397            +----+----+-----+-------+
2398            "#));
2399
2400        assert_join_metrics!(metrics, 3);
2401
2402        Ok(())
2403    }
2404
2405    #[tokio::test]
2406    async fn test_overallocation() -> Result<()> {
2407        let left = build_table(
2408            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2409            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2410            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2411            None,
2412            Vec::new(),
2413        );
2414        let right = build_table(
2415            ("a2", &vec![10, 11]),
2416            ("b2", &vec![12, 13]),
2417            ("c2", &vec![14, 15]),
2418            None,
2419            Vec::new(),
2420        );
2421        let filter = prepare_join_filter();
2422
2423        let join_types = vec![
2424            JoinType::Inner,
2425            JoinType::Left,
2426            JoinType::Right,
2427            JoinType::Full,
2428            JoinType::LeftSemi,
2429            JoinType::LeftAnti,
2430            JoinType::LeftMark,
2431            JoinType::RightSemi,
2432            JoinType::RightAnti,
2433            JoinType::RightMark,
2434        ];
2435
2436        for join_type in join_types {
2437            let runtime = RuntimeEnvBuilder::new()
2438                .with_memory_limit(100, 1.0)
2439                .build_arc()?;
2440            let task_ctx = TaskContext::default().with_runtime(runtime);
2441            let task_ctx = Arc::new(task_ctx);
2442
2443            let err = multi_partitioned_join_collect(
2444                Arc::clone(&left),
2445                Arc::clone(&right),
2446                &join_type,
2447                Some(filter.clone()),
2448                task_ctx,
2449            )
2450            .await
2451            .unwrap_err();
2452
2453            assert_contains!(
2454                err.to_string(),
2455                "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n  NestedLoopJoinLoad[0]"
2456            );
2457        }
2458
2459        Ok(())
2460    }
2461
2462    /// Returns the column names on the schema
2463    fn columns(schema: &Schema) -> Vec<String> {
2464        schema.fields().iter().map(|f| f.name().clone()).collect()
2465    }
2466}