datafusion_physical_plan/joins/nested_loop_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates).
19
20use std::any::Any;
21use std::fmt::Formatter;
22use std::ops::{BitOr, ControlFlow};
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::task::Poll;
26
27use super::utils::{
28    asymmetric_join_output_partitioning, need_produce_result_in_final,
29    reorder_output_after_swap, swap_join_projection,
30};
31use crate::common::can_project;
32use crate::execution_plan::{boundedness_from_children, EmissionType};
33use crate::joins::utils::{
34    build_join_schema, check_join_is_valid, estimate_join_statistics,
35    need_produce_right_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
36    OnceAsync, OnceFut,
37};
38use crate::joins::SharedBitmapBuilder;
39use crate::metrics::{
40    Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, MetricsSet, RatioMetrics,
41};
42use crate::projection::{
43    try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
44    ProjectionExec,
45};
46use crate::{
47    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
48    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
49};
50
51use arrow::array::{
52    new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions,
53    UInt64Array,
54};
55use arrow::buffer::BooleanBuffer;
56use arrow::compute::{
57    concat_batches, filter, filter_record_batch, not, take, BatchCoalescer,
58};
59use arrow::datatypes::{Schema, SchemaRef};
60use arrow::record_batch::RecordBatch;
61use arrow_schema::DataType;
62use datafusion_common::cast::as_boolean_array;
63use datafusion_common::{
64    arrow_err, internal_datafusion_err, internal_err, project_schema,
65    unwrap_or_internal_err, DataFusionError, JoinSide, Result, ScalarValue, Statistics,
66};
67use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
68use datafusion_execution::TaskContext;
69use datafusion_expr::JoinType;
70use datafusion_physical_expr::equivalence::{
71    join_equivalence_properties, ProjectionMapping,
72};
73
74use futures::{Stream, StreamExt, TryStreamExt};
75use log::debug;
76use parking_lot::Mutex;
77
78#[allow(rustdoc::private_intra_doc_links)]
79/// NestedLoopJoinExec is a build-probe join operator designed for joins that
80/// do not have equijoin keys in their `ON` clause.
81///
82/// # Execution Flow
83///
84/// ```text
85///                                                Incoming right batch
86///                Left Side Buffered Batches
87///                       ┌───────────┐              ┌───────────────┐
88///                       │ ┌───────┐ │              │               │
89///                       │ │       │ │              │               │
90///  Current Left Row ───▶│ ├───────├─┤──────────┐   │               │
91///                       │ │       │ │          │   └───────────────┘
92///                       │ │       │ │          │           │
93///                       │ │       │ │          │           │
94///                       │ └───────┘ │          │           │
95///                       │ ┌───────┐ │          │           │
96///                       │ │       │ │          │     ┌─────┘
97///                       │ │       │ │          │     │
98///                       │ │       │ │          │     │
99///                       │ │       │ │          │     │
100///                       │ │       │ │          │     │
101///                       │ └───────┘ │          ▼     ▼
102///                       │   ......  │  ┌──────────────────────┐
103///                       │           │  │X (Cartesian Product) │
104///                       │           │  └──────────┬───────────┘
105///                       └───────────┘             │
106///                                                 │
107///                                                 ▼
108///                                      ┌───────┬───────────────┐
109///                                      │       │               │
110///                                      │       │               │
111///                                      │       │               │
112///                                      └───────┴───────────────┘
113///                                        Intermediate Batch
114///                                  (For join predicate evaluation)
115/// ```
116///
117/// The execution follows a two-phase design:
118///
119/// ## 1. Buffering Left Input
120/// - The operator eagerly buffers all left-side input batches into memory,
///   until a memory limit is reached.
///   Currently, an out-of-memory error is returned if all the left-side input
///   batches cannot fit into memory at once.
///   In the future, it may be possible to let this case finish execution (see
///   the 'Memory-limited Execution' section).
126/// - The rationale for buffering the left side is that scanning the right side
127///   can be expensive (e.g., decoding Parquet files), so buffering more left
128///   rows reduces the number of right-side scan passes required.
129///
130/// ## 2. Probing Right Input
131/// - Right-side input is streamed batch by batch.
132/// - For each right-side batch:
133///   - It evaluates the join filter against the full buffered left input.
134///     This results in a Cartesian product between the right batch and each
135///     left row -- with the join predicate/filter applied -- for each inner
136///     loop iteration.
137///   - Matched results are accumulated into an output buffer. (see more in
138///     `Output Buffering Strategy` section)
139/// - This process continues until all right-side input is consumed.
140///
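/// As a rough sketch, the probing phase behaves like the following pseudocode
/// (illustrative only; the names do not correspond to the exact internal API):
///
/// ```text
/// for right_batch in right_stream:
///     for (left_idx, left_row) in buffered_left_batch.rows():
///         // Evaluate the join filter over (left_row x right_batch); the left
///         // row is conceptually repeated to the length of the right batch
///         mask = evaluate_filter(left_row, right_batch)
///         output_buffer.push(join(left_row, filter(right_batch, mask)))
///         if output_buffer has at least `batch_size` rows:
///             emit one output batch
/// ```
///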
141/// # Producing unmatched build-side data
/// - For special join types like left/full joins, unmatched rows must also be
///   output. During execution, bitmaps are kept for both the left and right
///   sides of the input; they are handled by dedicated states in
///   [`NestedLoopJoinStream`].
/// - The final output of unmatched left-side rows is handled by a single
///   partition for simplicity, since it accounts for only a small portion of
///   the execution time (e.g., if the probe side has 10k rows, emitting the
///   unmatched build side takes roughly 1/10k of the total time).
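/// - For example, in a `LEFT JOIN` the left bitmap marks every buffered left
///   row that matched at least one right row; once the right side is
///   exhausted, rows whose bits are still unset are emitted once, with NULLs
///   for the right-side columns.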
149///
150/// # Output Buffering Strategy
151/// The operator uses an intermediate output buffer to accumulate results. Once
152/// the output threshold is reached (currently set to the same value as
153/// `batch_size` in the configuration), the results will be eagerly output.
154///
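/// As a sketch (illustrative only), every state handler first drains the
/// output buffer before doing more join work:
///
/// ```text
/// if output_buffer has a completed batch:
///     return it to the caller
/// else:
///     keep joining and pushing rows into output_buffer
/// ```
///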
155/// # Extra Notes
156/// - The operator always considers the **left** side as the build (buffered) side.
157///   Therefore, the physical optimizer should assign the smaller input to the left.
/// - The design tries to keep the intermediate data size to approximately
///   one batch, for better cache locality and memory efficiency.
160///
161/// # TODO: Memory-limited Execution
162/// If the memory budget is exceeded during left-side buffering, fallback
163/// strategies such as streaming left batches and re-scanning the right side
164/// may be implemented in the future.
165///
166/// Tracking issue: <https://github.com/apache/datafusion/issues/15760>
167///
168/// # Clone / Shared State
169/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
170/// loading of the left side with the processing in each output stream.
/// Therefore it cannot be [`Clone`]
172#[derive(Debug)]
173pub struct NestedLoopJoinExec {
174    /// left side
175    pub(crate) left: Arc<dyn ExecutionPlan>,
176    /// right side
177    pub(crate) right: Arc<dyn ExecutionPlan>,
178    /// Filters which are applied while finding matching rows
179    pub(crate) filter: Option<JoinFilter>,
180    /// How the join is performed
181    pub(crate) join_type: JoinType,
    /// The full concatenated schema of the left and right children. Note that
    /// this may differ from the output schema of the operator (e.g. when a
    /// projection is applied).
184    join_schema: SchemaRef,
185    /// Future that consumes left input and buffers it in memory
186    ///
187    /// This structure is *shared* across all output streams.
188    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the build (left) side data, and buffers it all for later joining.
191    build_side_data: OnceAsync<JoinLeftData>,
192    /// Information of index and left / right placement of columns
193    column_indices: Vec<ColumnIndex>,
194    /// Projection to apply to the output of the join
195    projection: Option<Vec<usize>>,
196
197    /// Execution metrics
198    metrics: ExecutionPlanMetricsSet,
199    /// Cache holding plan properties like equivalences, output partitioning etc.
200    cache: PlanProperties,
201}
202
203impl NestedLoopJoinExec {
204    /// Try to create a new [`NestedLoopJoinExec`]
205    pub fn try_new(
206        left: Arc<dyn ExecutionPlan>,
207        right: Arc<dyn ExecutionPlan>,
208        filter: Option<JoinFilter>,
209        join_type: &JoinType,
210        projection: Option<Vec<usize>>,
211    ) -> Result<Self> {
212        let left_schema = left.schema();
213        let right_schema = right.schema();
214        check_join_is_valid(&left_schema, &right_schema, &[])?;
215        let (join_schema, column_indices) =
216            build_join_schema(&left_schema, &right_schema, join_type);
217        let join_schema = Arc::new(join_schema);
218        let cache = Self::compute_properties(
219            &left,
220            &right,
221            Arc::clone(&join_schema),
222            *join_type,
223            projection.as_ref(),
224        )?;
225
226        Ok(NestedLoopJoinExec {
227            left,
228            right,
229            filter,
230            join_type: *join_type,
231            join_schema,
232            build_side_data: Default::default(),
233            column_indices,
234            projection,
235            metrics: Default::default(),
236            cache,
237        })
238    }
239
240    /// left side
241    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
242        &self.left
243    }
244
245    /// right side
246    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
247        &self.right
248    }
249
250    /// Filters applied before join output
251    pub fn filter(&self) -> Option<&JoinFilter> {
252        self.filter.as_ref()
253    }
254
255    /// How the join is performed
256    pub fn join_type(&self) -> &JoinType {
257        &self.join_type
258    }
259
260    pub fn projection(&self) -> Option<&Vec<usize>> {
261        self.projection.as_ref()
262    }
263
264    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
265    fn compute_properties(
266        left: &Arc<dyn ExecutionPlan>,
267        right: &Arc<dyn ExecutionPlan>,
268        schema: SchemaRef,
269        join_type: JoinType,
270        projection: Option<&Vec<usize>>,
271    ) -> Result<PlanProperties> {
272        // Calculate equivalence properties:
273        let mut eq_properties = join_equivalence_properties(
274            left.equivalence_properties().clone(),
275            right.equivalence_properties().clone(),
276            &join_type,
277            Arc::clone(&schema),
278            &Self::maintains_input_order(join_type),
279            None,
280            // No on columns in nested loop join
281            &[],
282        )?;
283
284        let mut output_partitioning =
285            asymmetric_join_output_partitioning(left, right, &join_type)?;
286
287        let emission_type = if left.boundedness().is_unbounded() {
288            EmissionType::Final
289        } else if right.pipeline_behavior() == EmissionType::Incremental {
290            match join_type {
291                // If we only need to generate matched rows from the probe side,
292                // we can emit rows incrementally.
293                JoinType::Inner
294                | JoinType::LeftSemi
295                | JoinType::RightSemi
296                | JoinType::Right
297                | JoinType::RightAnti
298                | JoinType::RightMark => EmissionType::Incremental,
299                // If we need to generate unmatched rows from the *build side*,
300                // we need to emit them at the end.
301                JoinType::Left
302                | JoinType::LeftAnti
303                | JoinType::LeftMark
304                | JoinType::Full => EmissionType::Both,
305            }
306        } else {
307            right.pipeline_behavior()
308        };
309
310        if let Some(projection) = projection {
311            // construct a map from the input expressions to the output expression of the Projection
312            let projection_mapping =
313                ProjectionMapping::from_indices(projection, &schema)?;
314            let out_schema = project_schema(&schema, Some(projection))?;
315            output_partitioning =
316                output_partitioning.project(&projection_mapping, &eq_properties);
317            eq_properties = eq_properties.project(&projection_mapping, out_schema);
318        }
319
320        Ok(PlanProperties::new(
321            eq_properties,
322            output_partitioning,
323            emission_type,
324            boundedness_from_children([left, right]),
325        ))
326    }
327
328    /// This join implementation does not preserve the input order of either side.
329    fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
330        vec![false, false]
331    }
332
333    pub fn contains_projection(&self) -> bool {
334        self.projection.is_some()
335    }
336
337    pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
338        // check if the projection is valid
339        can_project(&self.schema(), projection.as_ref())?;
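        // If a projection already exists, compose it with the new one. For
        // example (hypothetical indices): if the existing projection is [2, 0]
        // and the new one is [1], the composed projection is [0], because the
        // new indices are looked up in the existing projection.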
340        let projection = match projection {
341            Some(projection) => match &self.projection {
342                Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
343                None => Some(projection),
344            },
345            None => None,
346        };
347        Self::try_new(
348            Arc::clone(&self.left),
349            Arc::clone(&self.right),
350            self.filter.clone(),
351            &self.join_type,
352            projection,
353        )
354    }
355
356    /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left
357    /// and right inputs swapped.
358    ///
359    /// # Notes:
360    ///
361    /// This function should be called BEFORE inserting any repartitioning
362    /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`]
363    /// for more details.
364    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
365        let left = self.left();
366        let right = self.right();
367        let new_join = NestedLoopJoinExec::try_new(
368            Arc::clone(right),
369            Arc::clone(left),
370            self.filter().map(JoinFilter::swap),
371            &self.join_type().swap(),
372            swap_join_projection(
373                left.schema().fields().len(),
374                right.schema().fields().len(),
375                self.projection.as_ref(),
376                self.join_type(),
377            ),
378        )?;
379
        // For Semi/Anti joins, the swapped join produces the same output
        // schema, so there is no need to wrap it in an additional projection
382        let plan: Arc<dyn ExecutionPlan> = if matches!(
383            self.join_type(),
384            JoinType::LeftSemi
385                | JoinType::RightSemi
386                | JoinType::LeftAnti
387                | JoinType::RightAnti
388                | JoinType::LeftMark
389                | JoinType::RightMark
390        ) || self.projection.is_some()
391        {
392            Arc::new(new_join)
393        } else {
394            reorder_output_after_swap(
395                Arc::new(new_join),
396                &self.left().schema(),
397                &self.right().schema(),
398            )?
399        };
400
401        Ok(plan)
402    }
403}
404
405impl DisplayAs for NestedLoopJoinExec {
406    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
407        match t {
408            DisplayFormatType::Default | DisplayFormatType::Verbose => {
409                let display_filter = self.filter.as_ref().map_or_else(
410                    || "".to_string(),
411                    |f| format!(", filter={}", f.expression()),
412                );
413                let display_projections = if self.contains_projection() {
414                    format!(
415                        ", projection=[{}]",
416                        self.projection
417                            .as_ref()
418                            .unwrap()
419                            .iter()
420                            .map(|index| format!(
421                                "{}@{}",
422                                self.join_schema.fields().get(*index).unwrap().name(),
423                                index
424                            ))
425                            .collect::<Vec<_>>()
426                            .join(", ")
427                    )
428                } else {
429                    "".to_string()
430                };
431                write!(
432                    f,
433                    "NestedLoopJoinExec: join_type={:?}{}{}",
434                    self.join_type, display_filter, display_projections
435                )
436            }
437            DisplayFormatType::TreeRender => {
438                if *self.join_type() != JoinType::Inner {
439                    writeln!(f, "join_type={:?}", self.join_type)
440                } else {
441                    Ok(())
442                }
443            }
444        }
445    }
446}
447
448impl ExecutionPlan for NestedLoopJoinExec {
449    fn name(&self) -> &'static str {
450        "NestedLoopJoinExec"
451    }
452
453    fn as_any(&self) -> &dyn Any {
454        self
455    }
456
457    fn properties(&self) -> &PlanProperties {
458        &self.cache
459    }
460
461    fn required_input_distribution(&self) -> Vec<Distribution> {
462        vec![
463            Distribution::SinglePartition,
464            Distribution::UnspecifiedDistribution,
465        ]
466    }
467
468    fn maintains_input_order(&self) -> Vec<bool> {
469        Self::maintains_input_order(self.join_type)
470    }
471
472    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
473        vec![&self.left, &self.right]
474    }
475
476    fn with_new_children(
477        self: Arc<Self>,
478        children: Vec<Arc<dyn ExecutionPlan>>,
479    ) -> Result<Arc<dyn ExecutionPlan>> {
480        Ok(Arc::new(NestedLoopJoinExec::try_new(
481            Arc::clone(&children[0]),
482            Arc::clone(&children[1]),
483            self.filter.clone(),
484            &self.join_type,
485            self.projection.clone(),
486        )?))
487    }
488
489    fn execute(
490        &self,
491        partition: usize,
492        context: Arc<TaskContext>,
493    ) -> Result<SendableRecordBatchStream> {
494        if self.left.output_partitioning().partition_count() != 1 {
495            return internal_err!(
496                "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
497                 consider using CoalescePartitionsExec or the EnforceDistribution rule"
498            );
499        }
500
501        let metrics = NestedLoopJoinMetrics::new(&self.metrics, partition);
502
        // Initialize the memory reservation for loading the inner (left) table
504        let load_reservation =
505            MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
506                .register(context.memory_pool());
507
508        let build_side_data = self.build_side_data.try_once(|| {
509            let stream = self.left.execute(0, Arc::clone(&context))?;
510
511            Ok(collect_left_input(
512                stream,
513                metrics.join_metrics.clone(),
514                load_reservation,
515                need_produce_result_in_final(self.join_type),
516                self.right().output_partitioning().partition_count(),
517            ))
518        })?;
519
520        let batch_size = context.session_config().batch_size();
521
522        let probe_side_data = self.right.execute(partition, context)?;
523
524        // update column indices to reflect the projection
525        let column_indices_after_projection = match &self.projection {
526            Some(projection) => projection
527                .iter()
528                .map(|i| self.column_indices[*i].clone())
529                .collect(),
530            None => self.column_indices.clone(),
531        };
532
533        Ok(Box::pin(NestedLoopJoinStream::new(
534            self.schema(),
535            self.filter.clone(),
536            self.join_type,
537            probe_side_data,
538            build_side_data,
539            column_indices_after_projection,
540            metrics,
541            batch_size,
542        )))
543    }
544
545    fn metrics(&self) -> Option<MetricsSet> {
546        Some(self.metrics.clone_inner())
547    }
548
549    fn statistics(&self) -> Result<Statistics> {
550        self.partition_statistics(None)
551    }
552
553    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
554        if partition.is_some() {
555            return Ok(Statistics::new_unknown(&self.schema()));
556        }
557        estimate_join_statistics(
558            self.left.partition_statistics(None)?,
559            self.right.partition_statistics(None)?,
560            vec![],
561            &self.join_type,
562            &self.schema(),
563        )
564    }
565
566    /// Tries to push `projection` down through `nested_loop_join`. If possible, performs the
567    /// pushdown and returns a new [`NestedLoopJoinExec`] as the top plan which has projections
568    /// as its children. Otherwise, returns `None`.
569    fn try_swapping_with_projection(
570        &self,
571        projection: &ProjectionExec,
572    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
573        // TODO: currently if there is projection in NestedLoopJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later.
574        if self.contains_projection() {
575            return Ok(None);
576        }
577
578        if let Some(JoinData {
579            projected_left_child,
580            projected_right_child,
581            join_filter,
582            ..
583        }) = try_pushdown_through_join(
584            projection,
585            self.left(),
586            self.right(),
587            &[],
588            self.schema(),
589            self.filter(),
590        )? {
591            Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
592                Arc::new(projected_left_child),
593                Arc::new(projected_right_child),
594                join_filter,
595                self.join_type(),
596                // Returned early if projection is not None
597                None,
598            )?)))
599        } else {
600            try_embed_projection(projection, self)
601        }
602    }
603}
604
605impl EmbeddedProjection for NestedLoopJoinExec {
606    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
607        self.with_projection(projection)
608    }
609}
610
611/// Left (build-side) data
612pub(crate) struct JoinLeftData {
613    /// Build-side data collected to single batch
614    batch: RecordBatch,
615    /// Shared bitmap builder for visited left indices
616    bitmap: SharedBitmapBuilder,
617    /// Counter of running probe-threads, potentially able to update `bitmap`
618    probe_threads_counter: AtomicUsize,
    /// Memory reservation tracking the batch and bitmap.
    /// The reservation is released when `JoinLeftData` is dropped.
622    #[expect(dead_code)]
623    reservation: MemoryReservation,
624}
625
626impl JoinLeftData {
627    pub(crate) fn new(
628        batch: RecordBatch,
629        bitmap: SharedBitmapBuilder,
630        probe_threads_counter: AtomicUsize,
631        reservation: MemoryReservation,
632    ) -> Self {
633        Self {
634            batch,
635            bitmap,
636            probe_threads_counter,
637            reservation,
638        }
639    }
640
641    pub(crate) fn batch(&self) -> &RecordBatch {
642        &self.batch
643    }
644
645    pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
646        &self.bitmap
647    }
648
649    /// Decrements counter of running threads, and returns `true`
650    /// if caller is the last running thread
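    /// (e.g. with 4 probe partitions the counter starts at 4; the caller that
    /// observes a previous value of 1 is the last one and gets `true`)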
651    pub(crate) fn report_probe_completed(&self) -> bool {
652        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
653    }
654}
655
/// Asynchronously collects the build-side input into a single batch and creates `JoinLeftData` from it
657async fn collect_left_input(
658    stream: SendableRecordBatchStream,
659    join_metrics: BuildProbeJoinMetrics,
660    reservation: MemoryReservation,
661    with_visited_left_side: bool,
662    probe_threads_count: usize,
663) -> Result<JoinLeftData> {
664    let schema = stream.schema();
665
666    // Load all batches and count the rows
667    let (batches, metrics, mut reservation) = stream
668        .try_fold(
669            (Vec::new(), join_metrics, reservation),
670            |(mut batches, metrics, mut reservation), batch| async {
671                let batch_size = batch.get_array_memory_size();
672                // Reserve memory for incoming batch
673                reservation.try_grow(batch_size)?;
674                // Update metrics
675                metrics.build_mem_used.add(batch_size);
676                metrics.build_input_batches.add(1);
677                metrics.build_input_rows.add(batch.num_rows());
678                // Push batch to output
679                batches.push(batch);
680                Ok((batches, metrics, reservation))
681            },
682        )
683        .await?;
684
685    let merged_batch = concat_batches(&schema, &batches)?;
686
687    // Reserve memory for visited_left_side bitmap if required by join type
688    let visited_left_side = if with_visited_left_side {
689        let n_rows = merged_batch.num_rows();
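        // One bit is needed per left row, rounded up to whole bytes
        // (e.g. 1000 buffered rows -> a 125 byte bitmap)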
690        let buffer_size = n_rows.div_ceil(8);
691        reservation.try_grow(buffer_size)?;
692        metrics.build_mem_used.add(buffer_size);
693
694        let mut buffer = BooleanBufferBuilder::new(n_rows);
695        buffer.append_n(n_rows, false);
696        buffer
697    } else {
698        BooleanBufferBuilder::new(0)
699    };
700
701    Ok(JoinLeftData::new(
702        merged_batch,
703        Mutex::new(visited_left_side),
704        AtomicUsize::new(probe_threads_count),
705        reservation,
706    ))
707}
708
709/// States for join processing. See `poll_next()` comment for more details about
710/// state transitions.
711#[derive(Debug, Clone, Copy)]
712enum NLJState {
713    BufferingLeft,
714    FetchingRight,
715    ProbeRight,
716    EmitRightUnmatched,
717    EmitLeftUnmatched,
718    Done,
719}
720pub(crate) struct NestedLoopJoinStream {
721    // ========================================================================
722    // PROPERTIES:
723    // Operator's properties that remain constant
724    //
725    // Note: The implementation uses the terms left/build-side table and
726    // right/probe-side table interchangeably. Treating the left side as the
727    // build side is a convention in DataFusion: the planner always tries to
728    // swap the smaller table to the left side.
729    // ========================================================================
730    /// Output schema
731    pub(crate) output_schema: Arc<Schema>,
732    /// join filter
733    pub(crate) join_filter: Option<JoinFilter>,
734    /// type of the join
735    pub(crate) join_type: JoinType,
    /// the probe-side (right) table data of the nested loop join
737    pub(crate) right_data: SendableRecordBatchStream,
738    /// the build-side table data of the nested loop join
739    pub(crate) left_data: OnceFut<JoinLeftData>,
740    /// Projection to construct the output schema from the left and right tables.
741    /// Example:
742    /// - output_schema: ['a', 'c']
743    /// - left_schema: ['a', 'b']
744    /// - right_schema: ['c']
745    ///
746    /// The column indices would be [(left, 0), (right, 0)] -- taking the left
747    /// 0th column and right 0th column can construct the output schema.
748    ///
749    /// Note there are other columns ('b' in the example) still kept after
750    /// projection pushdown; this is because they might be used to evaluate
751    /// the join filter (e.g., `JOIN ON (b+c)>0`).
752    pub(crate) column_indices: Vec<ColumnIndex>,
753    /// Join execution metrics
754    pub(crate) metrics: NestedLoopJoinMetrics,
755
756    /// `batch_size` from configuration
757    batch_size: usize,
758
759    /// See comments in [`need_produce_right_in_final`] for more detail
760    should_track_unmatched_right: bool,
761
762    // ========================================================================
763    // STATE FLAGS/BUFFERS:
764    // Fields that hold intermediate data/flags during execution
765    // ========================================================================
766    /// State Tracking
767    state: NLJState,
768    /// Output buffer holds the join result to output. It will emit eagerly when
769    /// the threshold is reached.
770    output_buffer: Box<BatchCoalescer>,
771    /// See comments in [`NLJState::Done`] for its purpose
772    handled_empty_output: bool,
773
774    // Buffer(left) side
775    // -----------------
776    /// The current buffered left data to join
777    buffered_left_data: Option<Arc<JoinLeftData>>,
778    /// Index into the left buffered batch. Used in `ProbeRight` state
779    left_probe_idx: usize,
780    /// Index into the left buffered batch. Used in `EmitLeftUnmatched` state
781    left_emit_idx: usize,
782    /// Should we go back to `BufferingLeft` state again after `EmitLeftUnmatched`
783    /// state is over.
784    left_exhausted: bool,
785    /// If we can buffer all left data in one pass
786    /// TODO(now): this is for the (unimplemented) memory-limited execution
787    #[allow(dead_code)]
788    left_buffered_in_one_pass: bool,
789
790    // Probe(right) side
791    // -----------------
792    /// The current probe batch to process
793    current_right_batch: Option<RecordBatch>,
    /// For right joins, keeps track of matched rows in `current_right_batch`.
    /// Constructed when fetching each new incoming right batch in the `FetchingRight` state.
796    current_right_batch_matched: Option<BooleanArray>,
797}
798
799pub(crate) struct NestedLoopJoinMetrics {
800    /// Join execution metrics
801    pub(crate) join_metrics: BuildProbeJoinMetrics,
802    /// Selectivity of the join: output_rows / (left_rows * right_rows)
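    /// (e.g. 100 left rows x 100 right rows producing 10 output rows gives a
    /// selectivity of 10 / 10_000 = 0.001)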
803    pub(crate) selectivity: RatioMetrics,
804}
805
806impl NestedLoopJoinMetrics {
807    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
808        Self {
809            join_metrics: BuildProbeJoinMetrics::new(partition, metrics),
810            selectivity: MetricBuilder::new(metrics)
811                .with_type(MetricType::SUMMARY)
812                .ratio_metrics("selectivity", partition),
813        }
814    }
815}
816
817impl Stream for NestedLoopJoinStream {
818    type Item = Result<RecordBatch>;
819
    /// See the comments on [`NestedLoopJoinExec`] for high-level design ideas.
821    ///
822    /// # Implementation
823    ///
824    /// This function is the entry point of NLJ operator's state machine
    /// transitions. The rough state transition graph is as follows; for more
    /// details see the comment in each state's matching arm.
827    ///
828    /// ============================
829    /// State transition graph:
830    /// ============================
831    ///
832    /// (start) --> BufferingLeft
833    /// ----------------------------
834    /// BufferingLeft → FetchingRight
835    ///
836    /// FetchingRight → ProbeRight (if right batch available)
837    /// FetchingRight → EmitLeftUnmatched (if right exhausted)
838    ///
839    /// ProbeRight → ProbeRight (next left row or after yielding output)
840    /// ProbeRight → EmitRightUnmatched (for special join types like right join)
841    /// ProbeRight → FetchingRight (done with the current right batch)
842    ///
843    /// EmitRightUnmatched → FetchingRight
844    ///
845    /// EmitLeftUnmatched → EmitLeftUnmatched (only process 1 chunk for each
846    /// iteration)
847    /// EmitLeftUnmatched → Done (if finished)
848    /// ----------------------------
849    /// Done → (end)
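    ///
    /// For example, an inner join whose probe side yields two batches roughly
    /// visits: BufferingLeft → FetchingRight → ProbeRight (once per buffered
    /// left row) → FetchingRight → ProbeRight ... → FetchingRight (stream
    /// exhausted) → EmitLeftUnmatched (a no-op for inner joins) → Done.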
850    fn poll_next(
851        mut self: std::pin::Pin<&mut Self>,
852        cx: &mut std::task::Context<'_>,
853    ) -> Poll<Option<Self::Item>> {
854        loop {
855            match self.state {
856                // # NLJState transitions
857                // --> FetchingRight
                // This state prepares the left-side batches; the next state,
                // `FetchingRight`, is responsible for preparing a single
                // probe-side batch before joining starts.
861                NLJState::BufferingLeft => {
862                    debug!("[NLJState] Entering: {:?}", self.state);
                    // Inside `collect_left_input` (the routine that buffers
                    // build-side batches), all related metrics except build
                    // time are updated.
866                    // stop on drop
867                    let build_metric = self.metrics.join_metrics.build_time.clone();
868                    let _build_timer = build_metric.timer();
869
870                    match self.handle_buffering_left(cx) {
871                        ControlFlow::Continue(()) => continue,
872                        ControlFlow::Break(poll) => return poll,
873                    }
874                }
875
876                // # NLJState transitions:
877                // 1. --> ProbeRight
878                //    Start processing the join for the newly fetched right
879                //    batch.
880                // 2. --> EmitLeftUnmatched: When the right side input is exhausted, (maybe) emit
881                //    unmatched left side rows.
882                //
883                // After fetching a new batch from the right side, it will
884                // process all rows from the buffered left data:
885                // ```text
886                // for batch in right_side:
887                //     for row in left_buffer:
888                //         join(batch, row)
889                // ```
890                // Note: the implementation does this step incrementally,
891                // instead of materializing all intermediate Cartesian products
892                // at once in memory.
893                //
894                // So after the right side input is exhausted, the join phase
895                // for the current buffered left data is finished. We can go to
896                // the next `EmitLeftUnmatched` phase to check if there is any
897                // special handling (e.g., in cases like left join).
898                NLJState::FetchingRight => {
899                    debug!("[NLJState] Entering: {:?}", self.state);
900                    // stop on drop
901                    let join_metric = self.metrics.join_metrics.join_time.clone();
902                    let _join_timer = join_metric.timer();
903
904                    match self.handle_fetching_right(cx) {
905                        ControlFlow::Continue(()) => continue,
906                        ControlFlow::Break(poll) => return poll,
907                    }
908                }
909
910                // NLJState transitions:
911                // 1. --> ProbeRight(1)
912                //    If we have already buffered enough output to yield, it
913                //    will first give back control to the parent state machine,
914                //    then resume at the same place.
915                // 2. --> ProbeRight(2)
916                //    After probing one right batch, and evaluating the
917                //    join filter on (left-row x right-batch), it will advance
918                //    to the next left row, then re-enter the current state and
919                //    continue joining.
                // 3. --> FetchingRight
                //    After it is done with the current right batch (joined
                //    against all rows in the left buffer), it will go to the
                //    FetchingRight state to decide what to do next.
924                NLJState::ProbeRight => {
925                    debug!("[NLJState] Entering: {:?}", self.state);
926
927                    // stop on drop
928                    let join_metric = self.metrics.join_metrics.join_time.clone();
929                    let _join_timer = join_metric.timer();
930
931                    match self.handle_probe_right() {
932                        ControlFlow::Continue(()) => continue,
933                        ControlFlow::Break(poll) => {
934                            return self.metrics.join_metrics.baseline.record_poll(poll)
935                        }
936                    }
937                }
938
                // In the `current_right_batch_matched` bitmap, a true bit means
                // the row has already been output by the join. In this state we
                // output the unmatched rows of the current right batch (with
                // null padding for the left relation).
                // Precondition: the join type has been checked, so emitting
                // unmatched right rows is possible (e.g. it's a right join).
945                NLJState::EmitRightUnmatched => {
946                    debug!("[NLJState] Entering: {:?}", self.state);
947
948                    // stop on drop
949                    let join_metric = self.metrics.join_metrics.join_time.clone();
950                    let _join_timer = join_metric.timer();
951
952                    match self.handle_emit_right_unmatched() {
953                        ControlFlow::Continue(()) => continue,
954                        ControlFlow::Break(poll) => {
955                            return self.metrics.join_metrics.baseline.record_poll(poll)
956                        }
957                    }
958                }
959
960                // NLJState transitions:
961                // 1. --> EmitLeftUnmatched(1)
962                //    If we have already buffered enough output to yield, it
963                //    will first give back control to the parent state machine,
964                //    then resume at the same place.
965                // 2. --> EmitLeftUnmatched(2)
966                //    After processing some unmatched rows, it will re-enter
967                //    the same state, to check if there are any more final
968                //    results to output.
969                // 3. --> Done
970                //    It has processed all data, go to the final state and ready
971                //    to exit.
972                //
973                // TODO: For memory-limited case, go back to `BufferingLeft`
974                // state again.
975                NLJState::EmitLeftUnmatched => {
976                    debug!("[NLJState] Entering: {:?}", self.state);
977
978                    // stop on drop
979                    let join_metric = self.metrics.join_metrics.join_time.clone();
980                    let _join_timer = join_metric.timer();
981
982                    match self.handle_emit_left_unmatched() {
983                        ControlFlow::Continue(()) => continue,
984                        ControlFlow::Break(poll) => {
985                            return self.metrics.join_metrics.baseline.record_poll(poll)
986                        }
987                    }
988                }
989
990                // The final state and the exit point
991                NLJState::Done => {
992                    debug!("[NLJState] Entering: {:?}", self.state);
993
994                    // stop on drop
995                    let join_metric = self.metrics.join_metrics.join_time.clone();
996                    let _join_timer = join_metric.timer();
                    // Count this in the join timer because there might be some
                    // final result batches to output in this state
999
1000                    let poll = self.handle_done();
1001                    return self.metrics.join_metrics.baseline.record_poll(poll);
1002                }
1003            }
1004        }
1005    }
1006}
1007
1008impl RecordBatchStream for NestedLoopJoinStream {
1009    fn schema(&self) -> SchemaRef {
1010        Arc::clone(&self.output_schema)
1011    }
1012}
1013
1014impl NestedLoopJoinStream {
1015    #[allow(clippy::too_many_arguments)]
1016    pub(crate) fn new(
1017        schema: Arc<Schema>,
1018        filter: Option<JoinFilter>,
1019        join_type: JoinType,
1020        right_data: SendableRecordBatchStream,
1021        left_data: OnceFut<JoinLeftData>,
1022        column_indices: Vec<ColumnIndex>,
1023        metrics: NestedLoopJoinMetrics,
1024        batch_size: usize,
1025    ) -> Self {
1026        Self {
1027            output_schema: Arc::clone(&schema),
1028            join_filter: filter,
1029            join_type,
1030            right_data,
1031            column_indices,
1032            left_data,
1033            metrics,
1034            buffered_left_data: None,
1035            output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
1036            batch_size,
1037            current_right_batch: None,
1038            current_right_batch_matched: None,
1039            state: NLJState::BufferingLeft,
1040            left_probe_idx: 0,
1041            left_emit_idx: 0,
1042            left_exhausted: false,
1043            left_buffered_in_one_pass: true,
1044            handled_empty_output: false,
1045            should_track_unmatched_right: need_produce_right_in_final(join_type),
1046        }
1047    }
1048
1049    // ==== State handler functions ====
1050
1051    /// Handle BufferingLeft state - prepare left side batches
1052    fn handle_buffering_left(
1053        &mut self,
1054        cx: &mut std::task::Context<'_>,
1055    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1056        match self.left_data.get_shared(cx) {
1057            Poll::Ready(Ok(left_data)) => {
1058                self.buffered_left_data = Some(left_data);
1059                // TODO: implement memory-limited case
1060                self.left_exhausted = true;
1061                self.state = NLJState::FetchingRight;
1062                // Continue to next state immediately
1063                ControlFlow::Continue(())
1064            }
1065            Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1066            Poll::Pending => ControlFlow::Break(Poll::Pending),
1067        }
1068    }
1069
1070    /// Handle FetchingRight state - fetch next right batch and prepare for processing
1071    fn handle_fetching_right(
1072        &mut self,
1073        cx: &mut std::task::Context<'_>,
1074    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1075        match self.right_data.poll_next_unpin(cx) {
1076            Poll::Ready(result) => match result {
1077                Some(Ok(right_batch)) => {
1078                    // Update metrics
1079                    let right_batch_size = right_batch.num_rows();
1080                    self.metrics.join_metrics.input_rows.add(right_batch_size);
1081                    self.metrics.join_metrics.input_batches.add(1);
1082
1083                    // Skip the empty batch
1084                    if right_batch_size == 0 {
1085                        return ControlFlow::Continue(());
1086                    }
1087
1088                    self.current_right_batch = Some(right_batch);
1089
1090                    // Prepare right bitmap
1091                    if self.should_track_unmatched_right {
1092                        let zeroed_buf = BooleanBuffer::new_unset(right_batch_size);
1093                        self.current_right_batch_matched =
1094                            Some(BooleanArray::new(zeroed_buf, None));
1095                    }
1096
1097                    self.left_probe_idx = 0;
1098                    self.state = NLJState::ProbeRight;
1099                    ControlFlow::Continue(())
1100                }
1101                Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1102                None => {
1103                    // Right stream exhausted
1104                    self.state = NLJState::EmitLeftUnmatched;
1105                    ControlFlow::Continue(())
1106                }
1107            },
1108            Poll::Pending => ControlFlow::Break(Poll::Pending),
1109        }
1110    }
1111
1112    /// Handle ProbeRight state - process current probe batch
1113    fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1114        // Return any completed batches first
1115        if let Some(poll) = self.maybe_flush_ready_batch() {
1116            return ControlFlow::Break(poll);
1117        }
1118
1119        // Process current probe state
1120        match self.process_probe_batch() {
            // State unchanged (ProbeRight)
            // Continue probing until we are done joining the
            // current right batch with all buffered left rows.
            Ok(true) => ControlFlow::Continue(()),
            // Move to the next state (FetchingRight or EmitRightUnmatched):
            // we have finished joining
            // (cur_right_batch x buffered_left_batches)
1128            Ok(false) => {
                // Finished probing all buffered left rows for this right batch;
                // reset the probe cursor
1130                self.left_probe_idx = 0;
1131
1132                // Selectivity Metric: Update total possibilities for the batch (left_rows * right_rows)
1133                // If memory-limited execution is implemented, this logic must be updated accordingly.
1134                if let (Ok(left_data), Some(right_batch)) =
1135                    (self.get_left_data(), self.current_right_batch.as_ref())
1136                {
1137                    let left_rows = left_data.batch().num_rows();
1138                    let right_rows = right_batch.num_rows();
1139                    self.metrics.selectivity.add_total(left_rows * right_rows);
1140                }
1141
1142                if self.should_track_unmatched_right {
1143                    debug_assert!(
1144                        self.current_right_batch_matched.is_some(),
1145                        "If it's required to track matched rows in the right input, the right bitmap must be present"
1146                    );
1147                    self.state = NLJState::EmitRightUnmatched;
1148                } else {
1149                    self.current_right_batch = None;
1150                    self.state = NLJState::FetchingRight;
1151                }
1152                ControlFlow::Continue(())
1153            }
1154            Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1155        }
1156    }
1157
1158    /// Handle EmitRightUnmatched state - emit unmatched right rows
1159    fn handle_emit_right_unmatched(
1160        &mut self,
1161    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1162        // Return any completed batches first
1163        if let Some(poll) = self.maybe_flush_ready_batch() {
1164            return ControlFlow::Break(poll);
1165        }
1166
1167        debug_assert!(
1168            self.current_right_batch_matched.is_some()
1169                && self.current_right_batch.is_some(),
1170            "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
1171        );
1172        // Construct the result batch for unmatched right rows using a utility function
1173        match self.process_right_unmatched() {
1174            Ok(Some(batch)) => {
1175                match self.output_buffer.push_batch(batch) {
1176                    Ok(()) => {
1177                        // Processed all in one pass
1178                        // cleared inside `process_right_unmatched`
1179                        debug_assert!(self.current_right_batch.is_none());
1180                        self.state = NLJState::FetchingRight;
1181                        ControlFlow::Continue(())
1182                    }
1183                    Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1184                }
1185            }
1186            Ok(None) => {
1187                // Processed all in one pass
1188                // cleared inside `process_right_unmatched`
1189                debug_assert!(self.current_right_batch.is_none());
1190                self.state = NLJState::FetchingRight;
1191                ControlFlow::Continue(())
1192            }
1193            Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1194        }
1195    }
1196
1197    /// Handle EmitLeftUnmatched state - emit unmatched left rows
1198    fn handle_emit_left_unmatched(
1199        &mut self,
1200    ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1201        // Return any completed batches first
1202        if let Some(poll) = self.maybe_flush_ready_batch() {
1203            return ControlFlow::Break(poll);
1204        }
1205
1206        // Process current unmatched state
1207        match self.process_left_unmatched() {
1208            // State unchanged (EmitLeftUnmatched)
1209            // Continue processing until we have processed all unmatched rows
1210            Ok(true) => ControlFlow::Continue(()),
1211            // To Done state
1212            // We have finished processing all unmatched rows
1213            Ok(false) => match self.output_buffer.finish_buffered_batch() {
1214                Ok(()) => {
1215                    self.state = NLJState::Done;
1216                    ControlFlow::Continue(())
1217                }
1218                Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1219            },
1220            Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1221        }
1222    }
1223
1224    /// Handle Done state - final state processing
1225    fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
1226        // Return any remaining completed batches before final termination
1227        if let Some(poll) = self.maybe_flush_ready_batch() {
1228            return poll;
1229        }
1230
1231        // HACK for the doc test in https://github.com/apache/datafusion/blob/main/datafusion/core/src/dataframe/mod.rs#L1265
        // If this operator directly returns `Poll::Ready(None)` for an empty
        // result, the final result becomes an empty batch with an empty
        // schema; however, the expected result is an empty batch with this
        // operator's output schema
1236        if !self.handled_empty_output {
1237            let zero_count = Count::new();
1238            if *self.metrics.join_metrics.baseline.output_rows() == zero_count {
1239                let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
1240                self.handled_empty_output = true;
1241                return Poll::Ready(Some(Ok(empty_batch)));
1242            }
1243        }
1244
1245        Poll::Ready(None)
1246    }
1247
1248    // ==== Core logic handling for each state ====
1249
    /// Returns a bool indicating whether probing should continue:
    /// true -> continue in the same ProbeRight state
    /// false -> done with (buffered_left x cur_right_batch); the caller will
    /// move to the next state (EmitRightUnmatched or FetchingRight)
1254    fn process_probe_batch(&mut self) -> Result<bool> {
1255        let left_data = Arc::clone(self.get_left_data()?);
1256        let right_batch = self
1257            .current_right_batch
1258            .as_ref()
1259            .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
1260            .clone();
1261
1262        // stop probing, the caller will go to the next state
1263        if self.left_probe_idx >= left_data.batch().num_rows() {
1264            return Ok(false);
1265        }
1266
1267        // ========
1268        // Join (l_row x right_batch)
1269        // and push the result into output_buffer
1270        // ========
1271
1272        let l_idx = self.left_probe_idx;
1273        let join_batch =
1274            self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
1275
1276        if let Some(batch) = join_batch {
1277            self.output_buffer.push_batch(batch)?;
1278        }
1279
1280        // ==== Prepare for the next iteration ====
1281
1282        // Advance left cursor
1283        self.left_probe_idx += 1;
1284
1285        // Return true to continue probing
1286        Ok(true)
1287    }
1288
1289    /// Process a single left row join with the current right batch.
1290    /// Returns a RecordBatch containing the join results (None if empty)
1291    fn process_single_left_row_join(
1292        &mut self,
1293        left_data: &JoinLeftData,
1294        right_batch: &RecordBatch,
1295        l_index: usize,
1296    ) -> Result<Option<RecordBatch>> {
1297        let right_row_count = right_batch.num_rows();
1298        if right_row_count == 0 {
1299            return Ok(None);
1300        }
1301
1302        let cur_right_bitmap = if let Some(filter) = &self.join_filter {
1303            apply_filter_to_row_join_batch(
1304                left_data.batch(),
1305                l_index,
1306                right_batch,
1307                filter,
1308            )?
1309        } else {
1310            BooleanArray::from(vec![true; right_row_count])
1311        };
1312
1313        self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
1314
        // For the following join types we only have to update the left/right
        // bitmaps here; there is no need to output a result
1317        if matches!(
1318            self.join_type,
1319            JoinType::LeftAnti
1320                | JoinType::LeftSemi
1321                | JoinType::LeftMark
1322                | JoinType::RightAnti
1323                | JoinType::RightMark
1324                | JoinType::RightSemi
1325        ) {
1326            return Ok(None);
1327        }
1328
1329        if cur_right_bitmap.true_count() == 0 {
1330            // If none of the pairs has passed the join predicate/filter
1331            Ok(None)
1332        } else {
1333            // Use the optimized approach similar to build_intermediate_batch_for_single_left_row
1334            let join_batch = build_row_join_batch(
1335                &self.output_schema,
1336                left_data.batch(),
1337                l_index,
1338                right_batch,
1339                Some(cur_right_bitmap),
1340                &self.column_indices,
1341                JoinSide::Left,
1342            )?;
1343            Ok(join_batch)
1344        }
1345    }
1346
    /// Returns a bool indicating whether to continue processing unmatched rows:
    /// true -> continue in the same EmitLeftUnmatched state
    /// false -> move to the next state (Done)
1350    fn process_left_unmatched(&mut self) -> Result<bool> {
1351        let left_data = self.get_left_data()?;
1352        let left_batch = left_data.batch();
1353
1354        // ========
1355        // Check early return conditions
1356        // ========
1357
1358        // Early return if join type can't have unmatched rows
1359        let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
1360        // Early return if another partition is already processing unmatched rows
1361        let handled_by_other_partition =
1362            self.left_emit_idx == 0 && !left_data.report_probe_completed();
1363        // Stop processing unmatched rows; the caller will go to the next state
1364        let finished = self.left_emit_idx >= left_batch.num_rows();
1365
1366        if join_type_no_produce_left || handled_by_other_partition || finished {
1367            return Ok(false);
1368        }
1369
1370        // ========
1371        // Process unmatched rows and push the result into output_buffer
1372        // Each call processes at most `batch_size` rows
1373        // ========
1374        let start_idx = self.left_emit_idx;
1375        let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
1376
1377        if let Some(batch) =
1378            self.process_left_unmatched_range(left_data, start_idx, end_idx)?
1379        {
1380            self.output_buffer.push_batch(batch)?;
1381        }
1382
1383        // ==== Prepare for the next iteration ====
1384        self.left_emit_idx = end_idx;
1385
1386        // Return true to continue processing unmatched rows
1387        Ok(true)
1388    }
1389
1390    /// Process unmatched rows from the left data within the specified range.
1391    /// Returns a RecordBatch containing the unmatched rows (None if empty).
1392    ///
1393    /// # Arguments
1394    /// * `left_data` - The left side data containing the batch and bitmap
1395    /// * `start_idx` - Start index (inclusive) of the range to process
1396    /// * `end_idx` - End index (exclusive) of the range to process
1397    ///
1398    /// # Safety
1399    /// The caller is responsible for ensuring that `start_idx` and `end_idx` are
1400    /// within valid bounds of the left batch. This function does not perform
1401    /// bounds checking.
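    ///
    /// Illustrative example (values invented): with a 5-row left batch whose
    /// matched bitmap is [true, false, true, false, false], calling this with
    /// start_idx = 1 and end_idx = 4 slices rows 1..4 and bits
    /// [false, true, false], then emits the unmatched rows 1 and 3 padded with
    /// nulls on the right side (for Left/Full joins).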
1402    fn process_left_unmatched_range(
1403        &self,
1404        left_data: &JoinLeftData,
1405        start_idx: usize,
1406        end_idx: usize,
1407    ) -> Result<Option<RecordBatch>> {
1408        if start_idx == end_idx {
1409            return Ok(None);
1410        }
1411
1412        // Slice both the left batch and the bitmap to the range [start_idx, end_idx)
1413        // The range is in bit indices (not bytes)
1414        let left_batch = left_data.batch();
1415        let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx);
1416
1417        // TODO(perf): slice the shared bitmap directly instead of copying bit by bit
1418        let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx);
1419        bitmap_sliced.append_n(end_idx - start_idx, false);
1420        let bitmap = left_data.bitmap().lock();
1421        for i in start_idx..end_idx {
1422            assert!(
1423                i - start_idx < bitmap_sliced.capacity(),
1424                "bit index out of range while slicing the left bitmap: start={start_idx}, end={end_idx}"
1425            );
1426            bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
1427        }
1428        let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
1429
1430        build_unmatched_batch(
1431            Arc::clone(&self.output_schema),
1432            &left_batch_sliced,
1433            bitmap_sliced,
1434            self.right_data.schema(),
1435            &self.column_indices,
1436            self.join_type,
1437            JoinSide::Left,
1438        )
1439    }
1440
1441    /// Process unmatched rows from the current right batch and reset the bitmap.
1442    /// Returns a RecordBatch containing the unmatched right rows (None if empty).
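    ///
    /// Illustrative example (values invented): for a Full join whose current
    /// right batch has 3 rows and whose right bitmap is [false, true, false],
    /// this emits rows 0 and 2 of the right batch with the left columns padded
    /// with nulls.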
1443    fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
1444        // ==== Take current right batch and its bitmap ====
1445        let right_batch_bitmap: BooleanArray =
1446            std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
1447                internal_datafusion_err!("right bitmap should be available")
1448            })?;
1449
1450        let right_batch = self.current_right_batch.take();
1451        let cur_right_batch = unwrap_or_internal_err!(right_batch);
1452
1453        let left_data = self.get_left_data()?;
1454        let left_schema = left_data.batch().schema();
1455
1456        let res = build_unmatched_batch(
1457            Arc::clone(&self.output_schema),
1458            &cur_right_batch,
1459            right_batch_bitmap,
1460            left_schema,
1461            &self.column_indices,
1462            self.join_type,
1463            JoinSide::Right,
1464        );
1465
1466        // ==== Clean-up ====
1467        self.current_right_batch_matched = None;
1468
1469        res
1470    }
1471
1472    // ==== Utilities ====
1473
1474    /// Get the build-side data of the left input; returns an error if it is `None`
1475    fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
1476        self.buffered_left_data
1477            .as_ref()
1478            .ok_or_else(|| internal_datafusion_err!("LeftData should be available"))
1479    }
1480
1481    /// Flush the `output_buffer` if there are batches ready to output.
1482    /// Returns `None` if no result batch is ready.
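    ///
    /// Rough behavior note: the `output_buffer` coalesces the small per-row
    /// results, so a completed batch only becomes available here once roughly
    /// `batch_size` rows have accumulated (or the buffer has been finished).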
1483    fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
1484        if self.output_buffer.has_completed_batch() {
1485            if let Some(batch) = self.output_buffer.next_completed_batch() {
1486                // HACK: this is not part of `BaselineMetrics` yet, so update it
1487                // manually
1488                self.metrics.join_metrics.output_batches.add(1);
1489
1490                // Update output rows for selectivity metric
1491                let output_rows = batch.num_rows();
1492                self.metrics.selectivity.add_part(output_rows);
1493
1494                return Some(Poll::Ready(Some(Ok(batch))));
1495            }
1496        }
1497
1498        None
1499    }
1500
1501    /// Joining (l_index@left_buffer x current_right_batch) produces a bitmap
1502    /// (the same length as current_right_batch) describing which right rows
1503    /// matched. Use this bitmap to update the global bitmaps for join types
1504    /// that need them (e.g. full joins).
1505    ///
1506    /// Example:
1507    /// After joining the row at index 1 in the left buffer with a current
1508    /// right batch of 3 rows, this function is called with arguments:
1509    /// l_index = 1, r_matched = [false, false, true]
1510    /// - If the join type is Full, bit 1 in the left bitmap is set to true,
1511    ///   and the right bitmap is bitwise-ORed with the input r_matched
1512    ///   bitmap.
1513    /// - For join types that don't need to output unmatched rows (e.g. inner
1514    ///   joins) this function is a no-op; for left joins only the left
1515    ///   bitmap is updated.
1516    fn update_matched_bitmap(
1517        &mut self,
1518        l_index: usize,
1519        r_matched_bitmap: &BooleanArray,
1520    ) -> Result<()> {
1521        let left_data = self.get_left_data()?;
1522
1523        // number of successfully joined pairs from (l_index x cur_right_batch)
1524        let joined_len = r_matched_bitmap.true_count();
1525
1526        // 1. Maybe update the left bitmap
1527        if need_produce_result_in_final(self.join_type) && (joined_len > 0) {
1528            let mut bitmap = left_data.bitmap().lock();
1529            bitmap.set_bit(l_index, true);
1530        }
1531
1532        // 2. Maybe update the right bitmap
1533        if self.should_track_unmatched_right {
1534            debug_assert!(self.current_right_batch_matched.is_some());
1535            // After the bitwise OR, it will be put back
1536            let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
1537                .ok_or_else(|| {
1538                    internal_datafusion_err!("right batch's bitmap should be present")
1539                })?;
1540            let (buf, nulls) = right_bitmap.into_parts();
1541            debug_assert!(nulls.is_none());
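            // Illustrative example (invented values): if the accumulated right
            // bitmap is [true, false, false] and this left row matched
            // [false, false, true], the OR below yields [true, false, true].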
1542            let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
1543
1544            self.current_right_batch_matched =
1545                Some(BooleanArray::new(updated_right_bitmap, None));
1546        }
1547
1548        Ok(())
1549    }
1550}
1551
1552// ==== Utilities ====
1553
1554/// Apply the join filter to:
1555/// (the `l_index`-th row in the left buffer) x (right batch)
1556/// Returns a bitmap with the successfully joined right-side indices set to true.
1557fn apply_filter_to_row_join_batch(
1558    left_batch: &RecordBatch,
1559    l_index: usize,
1560    right_batch: &RecordBatch,
1561    filter: &JoinFilter,
1562) -> Result<BooleanArray> {
1563    debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
1564
1565    let intermediate_batch = if filter.schema.fields().is_empty() {
1566        // If the filter is constant (e.g. the literal `true`), an empty batch can
1567        // be used in the later filter step.
1568        create_record_batch_with_empty_schema(
1569            Arc::new((*filter.schema).clone()),
1570            right_batch.num_rows(),
1571        )?
1572    } else {
1573        build_row_join_batch(
1574            &filter.schema,
1575            left_batch,
1576            l_index,
1577            right_batch,
1578            None,
1579            &filter.column_indices,
1580            JoinSide::Left,
1581        )?
1582        .ok_or_else(|| internal_datafusion_err!("This function assumes the input batches are not empty, so the intermediate batch can't be empty either"))?
1583    };
1584
1585    let filter_result = filter
1586        .expression()
1587        .evaluate(&intermediate_batch)?
1588        .into_array(intermediate_batch.num_rows())?;
1589    let filter_arr = as_boolean_array(&filter_result)?;
1590
1591    // [Caution] This step has previously introduced bugs
1592    // The filter result is NOT a bitmap; it contains true/false/null values.
1593    // For example, 1 < NULL is evaluated to NULL. Therefore, we must combine (AND)
1594    // the boolean array with its null bitmap to construct a unified bitmap.
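    // Hedged illustration (example predicate and values invented for this
    // comment): evaluating `l.a < r.b` with left value 1 and right column
    // [0, NULL, 7] yields [false, NULL, true]; NULL must count as "not
    // matched", so after AND-ing with the validity bitmap the combined
    // bitmap is [false, false, true].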
1595    let (is_filtered, nulls) = filter_arr.clone().into_parts();
1596    let bitmap_combined = match nulls {
1597        Some(nulls) => {
1598            let combined = nulls.inner() & &is_filtered;
1599            BooleanArray::new(combined, None)
1600        }
1601        None => BooleanArray::new(is_filtered, None),
1602    };
1603
1604    Ok(bitmap_combined)
1605}
1606
1607/// This function performs the following steps:
1608/// 1. Apply filter to probe-side batch
1609/// 2. Broadcast the build-side row (build_side_batch\[build_side_index\]) to
1610///    the filtered probe-side batch
1611/// 3. Combine the columns according to `col_indices`, and return the result
1612///    (None if the result is empty)
1613///
1614/// Example:
1615/// build_side_batch:
1616/// a
1617/// ----
1618/// 1
1619/// 2
1620/// 3
1621///
1622/// # 0 index element in the build_side_batch (that is `1`) will be used
1623/// build_side_index: 0
1624///
1625/// probe_side_batch:
1626/// b
1627/// ----
1628/// 10
1629/// 20
1630/// 30
1631/// 40
1632///
1633/// # After applying the filter, only the elements at indices 1 and 3 of
1634/// # probe_side_batch are kept
1635/// probe_side_filter:
1636/// false
1637/// true
1638/// false
1639/// true
1640///
1641///
1642/// # Projections to the build/probe side batch, to construct the output batch
1643/// col_indices:
1644/// [(left, 0), (right, 0)]
1645///
1646/// build_side: left
1647///
1648/// ====
1649/// Result batch:
1650/// a b
1651/// ----
1652/// 1 20
1653/// 1 40
1654fn build_row_join_batch(
1655    output_schema: &Schema,
1656    build_side_batch: &RecordBatch,
1657    build_side_index: usize,
1658    probe_side_batch: &RecordBatch,
1659    probe_side_filter: Option<BooleanArray>,
1660    // See [`NLJStream`] struct's `column_indices` field for more detail
1661    col_indices: &[ColumnIndex],
1662    // Whether the build side is left or right; used to interpret the side
1663    // information in `col_indices`
1664    build_side: JoinSide,
1665) -> Result<Option<RecordBatch>> {
1666    debug_assert!(build_side != JoinSide::None);
1667
1668    // TODO(perf): since the output might be a projection of the right batch, this
1669    // filtering step would be more efficient inside the column_index loop below
1670    let filtered_probe_batch = if let Some(filter) = probe_side_filter {
1671        &filter_record_batch(probe_side_batch, &filter)?
1672    } else {
1673        probe_side_batch
1674    };
1675
1676    if filtered_probe_batch.num_rows() == 0 {
1677        return Ok(None);
1678    }
1679
1680    // Edge case: downstream operator does not require any columns from this NLJ,
1681    // so allow an empty projection.
1682    // Example:
1683    //  SELECT DISTINCT 32 AS col2
1684    //  FROM tab0 AS cor0
1685    //  LEFT OUTER JOIN tab2 AS cor1
1686    //  ON ( NULL ) IS NULL;
1687    if output_schema.fields.is_empty() {
1688        return Ok(Some(create_record_batch_with_empty_schema(
1689            Arc::new(output_schema.clone()),
1690            filtered_probe_batch.num_rows(),
1691        )?));
1692    }
1693
1694    let mut columns: Vec<Arc<dyn Array>> =
1695        Vec::with_capacity(output_schema.fields().len());
1696
1697    for column_index in col_indices {
1698        let array = if column_index.side == build_side {
1699            // Broadcast the single build-side row to match the filtered
1700            // probe-side batch length
1701            let original_left_array = build_side_batch.column(column_index.index);
1702            // Avoid `ScalarValue::to_array_of_size()` for `List(Utf8View)`, since it
1703            // deep-copies the buffers inside the `Utf8View` array. See the issue below.
1704            // https://github.com/apache/datafusion/issues/18159
1705            //
1706            // In other cases, `to_array_of_size()` is faster.
1707            match original_left_array.data_type() {
1708                DataType::List(field) | DataType::LargeList(field)
1709                    if field.data_type() == &DataType::Utf8View =>
1710                {
1711                    let indices_iter = std::iter::repeat_n(
1712                        build_side_index as u64,
1713                        filtered_probe_batch.num_rows(),
1714                    );
1715                    let indices_array = UInt64Array::from_iter_values(indices_iter);
1716                    take(original_left_array.as_ref(), &indices_array, None)?
1717                }
1718                _ => {
1719                    let scalar_value = ScalarValue::try_from_array(
1720                        original_left_array.as_ref(),
1721                        build_side_index,
1722                    )?;
1723                    scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
1724                }
1725            }
1726        } else {
1727            // Reuse the already-filtered probe-side column directly
1728            Arc::clone(filtered_probe_batch.column(column_index.index))
1729        };
1730
1731        columns.push(array);
1732    }
1733
1734    Ok(Some(RecordBatch::try_new(
1735        Arc::new(output_schema.clone()),
1736        columns,
1737    )?))
1738}
1739
1740/// Special case for `PlaceholderRowExec`.
1741/// Minimal example:  SELECT 1 WHERE EXISTS (SELECT 1);
1742///
1743/// # Return
1744/// If `Some`, that's the result batch.
1745/// If `None`, this special case doesn't apply; continue execution.
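///
/// Illustrative example (values invented): with batch_bitmap = [true, false, false],
/// the empty-schema result has 2 rows for Left/Right/Full/Anti joins (the unmatched
/// count), 1 row for Semi joins (the matched count), and 3 rows for Mark joins
/// (one mark per input row).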
1746fn build_unmatched_batch_empty_schema(
1747    output_schema: SchemaRef,
1748    batch_bitmap: &BooleanArray,
1749    // The join type determines how many result rows to produce
1750    join_type: JoinType,
1751) -> Result<Option<RecordBatch>> {
1752    let result_size = match join_type {
1753        JoinType::Left
1754        | JoinType::Right
1755        | JoinType::Full
1756        | JoinType::LeftAnti
1757        | JoinType::RightAnti => batch_bitmap.false_count(),
1758        JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
1759        JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
1760        _ => unreachable!(),
1761    };
1762
1763    if output_schema.fields().is_empty() {
1764        Ok(Some(create_record_batch_with_empty_schema(
1765            Arc::clone(&output_schema),
1766            result_size,
1767        )?))
1768    } else {
1769        Ok(None)
1770    }
1771}
1772
1773/// Creates a RecordBatch with no columns but a specific row count.
1774/// This is useful when we need a batch with the correct schema and row count
1775/// but no actual data columns (e.g. for constant filters).
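///
/// A minimal usage sketch (illustrative only, not an actual call site in this
/// file):
///
/// ```text
/// let schema = Arc::new(Schema::empty());
/// let batch = create_record_batch_with_empty_schema(schema, 3)?;
/// assert_eq!((batch.num_columns(), batch.num_rows()), (0, 3));
/// ```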
1776fn create_record_batch_with_empty_schema(
1777    schema: SchemaRef,
1778    row_count: usize,
1779) -> Result<RecordBatch> {
1780    let options = RecordBatchOptions::new()
1781        .with_match_field_names(true)
1782        .with_row_count(Some(row_count));
1783
1784    RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| {
1785        internal_datafusion_err!("Failed to create empty record batch: {}", e)
1786    })
1787}
1788
1789/// # Example:
1790/// batch:
1791/// a
1792/// ----
1793/// 1
1794/// 2
1795/// 3
1796///
1797/// batch_bitmap:
1798/// ----
1799/// false
1800/// true
1801/// false
1802///
1803/// another_side_schema:
1804/// [(b, bool), (c, int32)]
1805///
1806/// join_type: JoinType::Left
1807///
1808/// col_indices: ... (see the comment on `NLJStream::column_indices`)
1809///
1810/// batch_side: right
1811///
1812/// # Walkthrough:
1813///
1814/// This executor is performing a right join, and the currently processed right
1815/// batch is as above. After joining it with all buffered left rows, the joined
1816/// entries are marked by the `batch_bitmap`.
1817/// This method will keep the unmatched indices on the batch side (right), and pad
1818/// the left side with nulls. The result would be:
1819///
1820/// b          c           a
1821/// ------------------------
1822/// Null(bool) Null(Int32) 1
1823/// Null(bool) Null(Int32) 3
1824fn build_unmatched_batch(
1825    output_schema: SchemaRef,
1826    batch: &RecordBatch,
1827    batch_bitmap: BooleanArray,
1828    // For left/right/full joins, it needs to fill nulls for another side
1829    another_side_schema: SchemaRef,
1830    col_indices: &[ColumnIndex],
1831    join_type: JoinType,
1832    batch_side: JoinSide,
1833) -> Result<Option<RecordBatch>> {
1834    // This function should not be called for inner joins
1835    debug_assert_ne!(join_type, JoinType::Inner);
1836    debug_assert_ne!(batch_side, JoinSide::None);
1837
1838    // Handle special case (see function comment)
1839    if let Some(batch) = build_unmatched_batch_empty_schema(
1840        Arc::clone(&output_schema),
1841        &batch_bitmap,
1842        join_type,
1843    )? {
1844        return Ok(Some(batch));
1845    }
1846
1847    match join_type {
1848        JoinType::Full | JoinType::Right | JoinType::Left => {
1849            if join_type == JoinType::Right {
1850                debug_assert_eq!(batch_side, JoinSide::Right);
1851            }
1852            if join_type == JoinType::Left {
1853                debug_assert_eq!(batch_side, JoinSide::Left);
1854            }
1855
1856            // 1. Filter the batch with *flipped* bitmap
1857            // 2. Fill the other (non-batch) side with nulls
1858            let flipped_bitmap = not(&batch_bitmap)?;
1859
1860            // Create a RecordBatch with the other side's schema, containing one all-null row
1861            let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
1862                .fields()
1863                .iter()
1864                .map(|field| new_null_array(field.data_type(), 1))
1865                .collect();
1866
1867            // Hack: even if the other side's schema is not nullable, the outer join
1868            // result may contain nulls; this temporary batch is only used to
1869            // construct that result.
1870            let nullable_left_schema = Arc::new(Schema::new(
1871                another_side_schema
1872                    .fields()
1873                    .iter()
1874                    .map(|field| {
1875                        (**field).clone().with_nullable(true)
1876                    })
1877                    .collect::<Vec<_>>(),
1878            ));
1879            let left_null_batch = if nullable_left_schema.fields.is_empty() {
1880                // The other input can be an empty relation; in that case it is not
1881                // used to construct the result batch (i.e. it is not in `col_indices`)
1882                create_record_batch_with_empty_schema(nullable_left_schema, 0)?
1883            } else {
1884                RecordBatch::try_new(nullable_left_schema, left_null_columns)?
1885            };
1886
1887            debug_assert_ne!(batch_side, JoinSide::None);
1888            let opposite_side = batch_side.negate();
1889
1890            build_row_join_batch(&output_schema, &left_null_batch, 0, batch, Some(flipped_bitmap), col_indices, opposite_side)
1891
1892        },
1893        JoinType::RightSemi | JoinType::RightAnti | JoinType::LeftSemi | JoinType::LeftAnti => {
1894            if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
1895                debug_assert_eq!(batch_side, JoinSide::Right);
1896            }
1897            if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1898                debug_assert_eq!(batch_side, JoinSide::Left);
1899            }
1900
1901            let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) {
1902                batch_bitmap.clone()
1903            } else {
1904                not(&batch_bitmap)?
1905            };
1906
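            // Example (invented values): for a LeftSemi join with batch bitmap
            // [true, false, true], the rows at indices 0 and 2 are kept; a
            // LeftAnti join instead keeps only the row at index 1.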
1907            if bitmap.true_count() == 0 {
1908                return Ok(None);
1909            }
1910
1911            let mut columns: Vec<Arc<dyn Array>> =
1912                Vec::with_capacity(output_schema.fields().len());
1913
1914            for column_index in col_indices {
1915                debug_assert!(column_index.side == batch_side);
1916
1917                let col = batch.column(column_index.index);
1918                let filtered_col = filter(col, &bitmap)?;
1919
1920                columns.push(filtered_col);
1921            }
1922
1923            Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1924        },
1925        JoinType::RightMark | JoinType::LeftMark => {
1926            if join_type == JoinType::RightMark {
1927                debug_assert_eq!(batch_side, JoinSide::Right);
1928            }
1929            if join_type == JoinType::LeftMark {
1930                debug_assert_eq!(batch_side, JoinSide::Left);
1931            }
1932
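            // Example (invented values): for a RightMark join over a 3-row right
            // batch with bitmap [false, true, false], all 3 right rows are kept
            // and the bitmap itself becomes the boolean `mark` column.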
1933            let mut columns: Vec<Arc<dyn Array>> =
1934                Vec::with_capacity(output_schema.fields().len());
1935
1936            // Wrap the bitmap in an Option so it can be moved out exactly once
1937            let mut right_batch_bitmap_opt = Some(batch_bitmap);
1938
1939            for column_index in col_indices {
1940                if column_index.side == batch_side {
1941                    let col = batch.column(column_index.index);
1942
1943                    columns.push(Arc::clone(col));
1944                } else if column_index.side == JoinSide::None {
1945                    let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
1946                    match right_batch_bitmap {
1947                        Some(right_batch_bitmap) => {columns.push(Arc::new(right_batch_bitmap))},
1948                        None => unreachable!("Should only be one mark column"),
1949                    }
1950                } else {
1951                    return internal_err!("Unexpected join side for a mark join");
1952                }
1953            }
1954
1955            Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1956        }
1957        _ => internal_err!("build_unmatched_batch should not be called for this join type (e.g. Inner)"),
1958    }
1959}
1960
1961#[cfg(test)]
1962pub(crate) mod tests {
1963    use super::*;
1964    use crate::test::{assert_join_metrics, TestMemoryExec};
1965    use crate::{
1966        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
1967    };
1968
1969    use arrow::compute::SortOptions;
1970    use arrow::datatypes::{DataType, Field};
1971    use datafusion_common::test_util::batches_to_sort_string;
1972    use datafusion_common::{assert_contains, ScalarValue};
1973    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1974    use datafusion_expr::Operator;
1975    use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
1976    use datafusion_physical_expr::{Partitioning, PhysicalExpr};
1977    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
1978
1979    use insta::allow_duplicates;
1980    use insta::assert_snapshot;
1981    use rstest::rstest;
1982
1983    fn build_table(
1984        a: (&str, &Vec<i32>),
1985        b: (&str, &Vec<i32>),
1986        c: (&str, &Vec<i32>),
1987        batch_size: Option<usize>,
1988        sorted_column_names: Vec<&str>,
1989    ) -> Arc<dyn ExecutionPlan> {
1990        let batch = build_table_i32(a, b, c);
1991        let schema = batch.schema();
1992
1993        let batches = if let Some(batch_size) = batch_size {
1994            let num_batches = batch.num_rows().div_ceil(batch_size);
1995            (0..num_batches)
1996                .map(|i| {
1997                    let start = i * batch_size;
1998                    let remaining_rows = batch.num_rows() - start;
1999                    batch.slice(start, batch_size.min(remaining_rows))
2000                })
2001                .collect::<Vec<_>>()
2002        } else {
2003            vec![batch]
2004        };
2005
2006        let mut sort_info = vec![];
2007        for name in sorted_column_names {
2008            let index = schema.index_of(name).unwrap();
2009            let sort_expr = PhysicalSortExpr::new(
2010                Arc::new(Column::new(name, index)),
2011                SortOptions::new(false, false),
2012            );
2013            sort_info.push(sort_expr);
2014        }
2015        let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
2016        if let Some(ordering) = LexOrdering::new(sort_info) {
2017            source = source.try_with_sort_information(vec![ordering]).unwrap();
2018        }
2019
2020        Arc::new(TestMemoryExec::update_cache(Arc::new(source)))
2021    }
2022
2023    fn build_left_table() -> Arc<dyn ExecutionPlan> {
2024        build_table(
2025            ("a1", &vec![5, 9, 11]),
2026            ("b1", &vec![5, 8, 8]),
2027            ("c1", &vec![50, 90, 110]),
2028            None,
2029            Vec::new(),
2030        )
2031    }
2032
2033    fn build_right_table() -> Arc<dyn ExecutionPlan> {
2034        build_table(
2035            ("a2", &vec![12, 2, 10]),
2036            ("b2", &vec![10, 2, 10]),
2037            ("c2", &vec![40, 80, 100]),
2038            None,
2039            Vec::new(),
2040        )
2041    }
2042
2043    fn prepare_join_filter() -> JoinFilter {
2044        let column_indices = vec![
2045            ColumnIndex {
2046                index: 1,
2047                side: JoinSide::Left,
2048            },
2049            ColumnIndex {
2050                index: 1,
2051                side: JoinSide::Right,
2052            },
2053        ];
2054        let intermediate_schema = Schema::new(vec![
2055            Field::new("x", DataType::Int32, true),
2056            Field::new("x", DataType::Int32, true),
2057        ]);
2058        // left.b1!=8
2059        let left_filter = Arc::new(BinaryExpr::new(
2060            Arc::new(Column::new("x", 0)),
2061            Operator::NotEq,
2062            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
2063        )) as Arc<dyn PhysicalExpr>;
2064        // right.b2!=10
2065        let right_filter = Arc::new(BinaryExpr::new(
2066            Arc::new(Column::new("x", 1)),
2067            Operator::NotEq,
2068            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2069        )) as Arc<dyn PhysicalExpr>;
2070        // filter = left.b1!=8 and right.b2!=10
2071        // after filter:
2072        // left table:
2073        // ("a1", &vec![5]),
2074        // ("b1", &vec![5]),
2075        // ("c1", &vec![50]),
2076        // right table:
2077        // ("a2", &vec![12, 2]),
2078        // ("b2", &vec![10, 2]),
2079        // ("c2", &vec![40, 80]),
2080        let filter_expression =
2081            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
2082                as Arc<dyn PhysicalExpr>;
2083
2084        JoinFilter::new(
2085            filter_expression,
2086            column_indices,
2087            Arc::new(intermediate_schema),
2088        )
2089    }
2090
2091    pub(crate) async fn multi_partitioned_join_collect(
2092        left: Arc<dyn ExecutionPlan>,
2093        right: Arc<dyn ExecutionPlan>,
2094        join_type: &JoinType,
2095        join_filter: Option<JoinFilter>,
2096        context: Arc<TaskContext>,
2097    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
2098        let partition_count = 4;
2099
2100        // Redistributing right input
2101        let right = Arc::new(RepartitionExec::try_new(
2102            right,
2103            Partitioning::RoundRobinBatch(partition_count),
2104        )?) as Arc<dyn ExecutionPlan>;
2105
2106        // Use the nested loop join's required distribution to test partitioned data
2107        let nested_loop_join =
2108            NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
2109        let columns = columns(&nested_loop_join.schema());
2110        let mut batches = vec![];
2111        for i in 0..partition_count {
2112            let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
2113            let more_batches = common::collect(stream).await?;
2114            batches.extend(
2115                more_batches
2116                    .into_iter()
2117                    .inspect(|b| {
2118                        assert!(b.num_rows() <= context.session_config().batch_size())
2119                    })
2120                    .filter(|b| b.num_rows() > 0)
2121                    .collect::<Vec<_>>(),
2122            );
2123        }
2124
2125        let metrics = nested_loop_join.metrics().unwrap();
2126
2127        Ok((columns, batches, metrics))
2128    }
2129
2130    fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
2131        let base = TaskContext::default();
2132        // Limit the max size of intermediate batches used in the NLJ to `batch_size`
2133        let cfg = base.session_config().clone().with_batch_size(batch_size);
2134        Arc::new(base.with_session_config(cfg))
2135    }
2136
2137    #[rstest]
2138    #[tokio::test]
2139    async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2140        let task_ctx = new_task_ctx(batch_size);
2142        let left = build_left_table();
2143        let right = build_right_table();
2144        let filter = prepare_join_filter();
2145        let (columns, batches, metrics) = multi_partitioned_join_collect(
2146            left,
2147            right,
2148            &JoinType::Inner,
2149            Some(filter),
2150            task_ctx,
2151        )
2152        .await?;
2153
2154        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2155        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2156            +----+----+----+----+----+----+
2157            | a1 | b1 | c1 | a2 | b2 | c2 |
2158            +----+----+----+----+----+----+
2159            | 5  | 5  | 50 | 2  | 2  | 80 |
2160            +----+----+----+----+----+----+
2161            "#));
2162
2163        assert_join_metrics!(metrics, 1);
2164
2165        Ok(())
2166    }
2167
2168    #[rstest]
2169    #[tokio::test]
2170    async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2171        let task_ctx = new_task_ctx(batch_size);
2172        let left = build_left_table();
2173        let right = build_right_table();
2174
2175        let filter = prepare_join_filter();
2176        let (columns, batches, metrics) = multi_partitioned_join_collect(
2177            left,
2178            right,
2179            &JoinType::Left,
2180            Some(filter),
2181            task_ctx,
2182        )
2183        .await?;
2184        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2185        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2186            +----+----+-----+----+----+----+
2187            | a1 | b1 | c1  | a2 | b2 | c2 |
2188            +----+----+-----+----+----+----+
2189            | 11 | 8  | 110 |    |    |    |
2190            | 5  | 5  | 50  | 2  | 2  | 80 |
2191            | 9  | 8  | 90  |    |    |    |
2192            +----+----+-----+----+----+----+
2193            "#));
2194
2195        assert_join_metrics!(metrics, 3);
2196
2197        Ok(())
2198    }
2199
2200    #[rstest]
2201    #[tokio::test]
2202    async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2203        let task_ctx = new_task_ctx(batch_size);
2204        let left = build_left_table();
2205        let right = build_right_table();
2206
2207        let filter = prepare_join_filter();
2208        let (columns, batches, metrics) = multi_partitioned_join_collect(
2209            left,
2210            right,
2211            &JoinType::Right,
2212            Some(filter),
2213            task_ctx,
2214        )
2215        .await?;
2216        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2217        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2218            +----+----+----+----+----+-----+
2219            | a1 | b1 | c1 | a2 | b2 | c2  |
2220            +----+----+----+----+----+-----+
2221            |    |    |    | 10 | 10 | 100 |
2222            |    |    |    | 12 | 10 | 40  |
2223            | 5  | 5  | 50 | 2  | 2  | 80  |
2224            +----+----+----+----+----+-----+
2225            "#));
2226
2227        assert_join_metrics!(metrics, 3);
2228
2229        Ok(())
2230    }
2231
2232    #[rstest]
2233    #[tokio::test]
2234    async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2235        let task_ctx = new_task_ctx(batch_size);
2236        let left = build_left_table();
2237        let right = build_right_table();
2238
2239        let filter = prepare_join_filter();
2240        let (columns, batches, metrics) = multi_partitioned_join_collect(
2241            left,
2242            right,
2243            &JoinType::Full,
2244            Some(filter),
2245            task_ctx,
2246        )
2247        .await?;
2248        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2249        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2250            +----+----+-----+----+----+-----+
2251            | a1 | b1 | c1  | a2 | b2 | c2  |
2252            +----+----+-----+----+----+-----+
2253            |    |    |     | 10 | 10 | 100 |
2254            |    |    |     | 12 | 10 | 40  |
2255            | 11 | 8  | 110 |    |    |     |
2256            | 5  | 5  | 50  | 2  | 2  | 80  |
2257            | 9  | 8  | 90  |    |    |     |
2258            +----+----+-----+----+----+-----+
2259            "#));
2260
2261        assert_join_metrics!(metrics, 5);
2262
2263        Ok(())
2264    }
2265
2266    #[rstest]
2267    #[tokio::test]
2268    async fn join_left_semi_with_filter(
2269        #[values(1, 2, 16)] batch_size: usize,
2270    ) -> Result<()> {
2271        let task_ctx = new_task_ctx(batch_size);
2272        let left = build_left_table();
2273        let right = build_right_table();
2274
2275        let filter = prepare_join_filter();
2276        let (columns, batches, metrics) = multi_partitioned_join_collect(
2277            left,
2278            right,
2279            &JoinType::LeftSemi,
2280            Some(filter),
2281            task_ctx,
2282        )
2283        .await?;
2284        assert_eq!(columns, vec!["a1", "b1", "c1"]);
2285        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2286            +----+----+----+
2287            | a1 | b1 | c1 |
2288            +----+----+----+
2289            | 5  | 5  | 50 |
2290            +----+----+----+
2291            "#));
2292
2293        assert_join_metrics!(metrics, 1);
2294
2295        Ok(())
2296    }
2297
2298    #[rstest]
2299    #[tokio::test]
2300    async fn join_left_anti_with_filter(
2301        #[values(1, 2, 16)] batch_size: usize,
2302    ) -> Result<()> {
2303        let task_ctx = new_task_ctx(batch_size);
2304        let left = build_left_table();
2305        let right = build_right_table();
2306
2307        let filter = prepare_join_filter();
2308        let (columns, batches, metrics) = multi_partitioned_join_collect(
2309            left,
2310            right,
2311            &JoinType::LeftAnti,
2312            Some(filter),
2313            task_ctx,
2314        )
2315        .await?;
2316        assert_eq!(columns, vec!["a1", "b1", "c1"]);
2317        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2318            +----+----+-----+
2319            | a1 | b1 | c1  |
2320            +----+----+-----+
2321            | 11 | 8  | 110 |
2322            | 9  | 8  | 90  |
2323            +----+----+-----+
2324            "#));
2325
2326        assert_join_metrics!(metrics, 2);
2327
2328        Ok(())
2329    }
2330
2331    #[tokio::test]
2332    async fn join_has_correct_stats() -> Result<()> {
2333        let left = build_left_table();
2334        let right = build_right_table();
2335        let nested_loop_join = NestedLoopJoinExec::try_new(
2336            left,
2337            right,
2338            None,
2339            &JoinType::Left,
2340            Some(vec![1, 2]),
2341        )?;
2342        let stats = nested_loop_join.partition_statistics(None)?;
2343        assert_eq!(
2344            nested_loop_join.schema().fields().len(),
2345            stats.column_statistics.len(),
2346        );
2347        assert_eq!(2, stats.column_statistics.len());
2348        Ok(())
2349    }
2350
2351    #[rstest]
2352    #[tokio::test]
2353    async fn join_right_semi_with_filter(
2354        #[values(1, 2, 16)] batch_size: usize,
2355    ) -> Result<()> {
2356        let task_ctx = new_task_ctx(batch_size);
2357        let left = build_left_table();
2358        let right = build_right_table();
2359
2360        let filter = prepare_join_filter();
2361        let (columns, batches, metrics) = multi_partitioned_join_collect(
2362            left,
2363            right,
2364            &JoinType::RightSemi,
2365            Some(filter),
2366            task_ctx,
2367        )
2368        .await?;
2369        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2370        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2371            +----+----+----+
2372            | a2 | b2 | c2 |
2373            +----+----+----+
2374            | 2  | 2  | 80 |
2375            +----+----+----+
2376            "#));
2377
2378        assert_join_metrics!(metrics, 1);
2379
2380        Ok(())
2381    }
2382
2383    #[rstest]
2384    #[tokio::test]
2385    async fn join_right_anti_with_filter(
2386        #[values(1, 2, 16)] batch_size: usize,
2387    ) -> Result<()> {
2388        let task_ctx = new_task_ctx(batch_size);
2389        let left = build_left_table();
2390        let right = build_right_table();
2391
2392        let filter = prepare_join_filter();
2393        let (columns, batches, metrics) = multi_partitioned_join_collect(
2394            left,
2395            right,
2396            &JoinType::RightAnti,
2397            Some(filter),
2398            task_ctx,
2399        )
2400        .await?;
2401        assert_eq!(columns, vec!["a2", "b2", "c2"]);
2402        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2403            +----+----+-----+
2404            | a2 | b2 | c2  |
2405            +----+----+-----+
2406            | 10 | 10 | 100 |
2407            | 12 | 10 | 40  |
2408            +----+----+-----+
2409            "#));
2410
2411        assert_join_metrics!(metrics, 2);
2412
2413        Ok(())
2414    }
2415
2416    #[rstest]
2417    #[tokio::test]
2418    async fn join_left_mark_with_filter(
2419        #[values(1, 2, 16)] batch_size: usize,
2420    ) -> Result<()> {
2421        let task_ctx = new_task_ctx(batch_size);
2422        let left = build_left_table();
2423        let right = build_right_table();
2424
2425        let filter = prepare_join_filter();
2426        let (columns, batches, metrics) = multi_partitioned_join_collect(
2427            left,
2428            right,
2429            &JoinType::LeftMark,
2430            Some(filter),
2431            task_ctx,
2432        )
2433        .await?;
2434        assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
2435        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2436            +----+----+-----+-------+
2437            | a1 | b1 | c1  | mark  |
2438            +----+----+-----+-------+
2439            | 11 | 8  | 110 | false |
2440            | 5  | 5  | 50  | true  |
2441            | 9  | 8  | 90  | false |
2442            +----+----+-----+-------+
2443            "#));
2444
2445        assert_join_metrics!(metrics, 3);
2446
2447        Ok(())
2448    }
2449
2450    #[rstest]
2451    #[tokio::test]
2452    async fn join_right_mark_with_filter(
2453        #[values(1, 2, 16)] batch_size: usize,
2454    ) -> Result<()> {
2455        let task_ctx = new_task_ctx(batch_size);
2456        let left = build_left_table();
2457        let right = build_right_table();
2458
2459        let filter = prepare_join_filter();
2460        let (columns, batches, metrics) = multi_partitioned_join_collect(
2461            left,
2462            right,
2463            &JoinType::RightMark,
2464            Some(filter),
2465            task_ctx,
2466        )
2467        .await?;
2468        assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
2469
2470        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2471            +----+----+-----+-------+
2472            | a2 | b2 | c2  | mark  |
2473            +----+----+-----+-------+
2474            | 10 | 10 | 100 | false |
2475            | 12 | 10 | 40  | false |
2476            | 2  | 2  | 80  | true  |
2477            +----+----+-----+-------+
2478            "#));
2479
2480        assert_join_metrics!(metrics, 3);
2481
2482        Ok(())
2483    }
2484
2485    #[tokio::test]
2486    async fn test_overallocation() -> Result<()> {
2487        let left = build_table(
2488            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2489            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2490            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2491            None,
2492            Vec::new(),
2493        );
2494        let right = build_table(
2495            ("a2", &vec![10, 11]),
2496            ("b2", &vec![12, 13]),
2497            ("c2", &vec![14, 15]),
2498            None,
2499            Vec::new(),
2500        );
2501        let filter = prepare_join_filter();
2502
2503        let join_types = vec![
2504            JoinType::Inner,
2505            JoinType::Left,
2506            JoinType::Right,
2507            JoinType::Full,
2508            JoinType::LeftSemi,
2509            JoinType::LeftAnti,
2510            JoinType::LeftMark,
2511            JoinType::RightSemi,
2512            JoinType::RightAnti,
2513            JoinType::RightMark,
2514        ];
2515
2516        for join_type in join_types {
2517            let runtime = RuntimeEnvBuilder::new()
2518                .with_memory_limit(100, 1.0)
2519                .build_arc()?;
2520            let task_ctx = TaskContext::default().with_runtime(runtime);
2521            let task_ctx = Arc::new(task_ctx);
2522
2523            let err = multi_partitioned_join_collect(
2524                Arc::clone(&left),
2525                Arc::clone(&right),
2526                &join_type,
2527                Some(filter.clone()),
2528                task_ctx,
2529            )
2530            .await
2531            .unwrap_err();
2532
2533            assert_contains!(
2534                err.to_string(),
2535                "Resources exhausted: Additional allocation failed for NestedLoopJoinLoad[0] with top memory consumers (across reservations) as:\n  NestedLoopJoinLoad[0]"
2536            );
2537        }
2538
2539        Ok(())
2540    }
2541
2542    /// Returns the column names of the schema
2543    fn columns(schema: &Schema) -> Vec<String> {
2544        schema.fields().iter().map(|f| f.name().clone()).collect()
2545    }
2546}