1use std::fmt::Formatter;
21use std::ops::{BitOr, ControlFlow};
22use std::sync::Arc;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::task::Poll;
25
26use super::utils::{
27 asymmetric_join_output_partitioning, need_produce_result_in_final,
28 reorder_output_after_swap, swap_join_projection,
29};
30use crate::common::can_project;
31use crate::execution_plan::{EmissionType, boundedness_from_children};
32use crate::joins::SharedBitmapBuilder;
33use crate::joins::utils::{
34 BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut,
35 build_join_schema, check_join_is_valid, estimate_join_statistics,
36 need_produce_right_in_final,
37};
38use crate::metrics::{
39 Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, MetricsSet, RatioMetrics,
40};
41use crate::projection::{
42 EmbeddedProjection, JoinData, ProjectionExec, try_embed_projection,
43 try_pushdown_through_join,
44};
45use crate::{
46 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
47 PlanProperties, RecordBatchStream, SendableRecordBatchStream,
48 check_if_same_properties,
49};
50
51use arrow::array::{
52 Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, UInt32Array,
53 UInt64Array, new_null_array,
54};
55use arrow::buffer::BooleanBuffer;
56use arrow::compute::{
57 BatchCoalescer, concat_batches, filter, filter_record_batch, not, take,
58};
59use arrow::datatypes::{Schema, SchemaRef};
60use arrow::record_batch::RecordBatch;
61use arrow_schema::DataType;
62use datafusion_common::cast::as_boolean_array;
63use datafusion_common::{
64 JoinSide, Result, ScalarValue, Statistics, arrow_err, assert_eq_or_internal_err,
65 internal_datafusion_err, internal_err, project_schema, unwrap_or_internal_err,
66};
67use datafusion_execution::TaskContext;
68use datafusion_execution::disk_manager::RefCountedTempFile;
69use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
70use datafusion_expr::JoinType;
71use datafusion_physical_expr::equivalence::{
72 ProjectionMapping, join_equivalence_properties,
73};
74
75use datafusion_physical_expr::projection::{ProjectionRef, combine_projections};
76use futures::{Stream, StreamExt, TryStreamExt};
77use log::debug;
78use parking_lot::Mutex;
79
80use crate::metrics::SpillMetrics;
81use crate::spill::replayable_spill_input::ReplayableStreamSource;
82use crate::spill::spill_manager::SpillManager;
83
#[expect(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
/// Physical plan node that joins two inputs with a nested-loop strategy.
///
/// `execute` reads partition 0 of `left`, buffers it as [`JoinLeftData`], and
/// probes it with one partition of `right`; the per-partition work is driven
/// by [`NestedLoopJoinStream`].
///
/// Instances are created with [`NestedLoopJoinExecBuilder`] or
/// [`NestedLoopJoinExec::try_new`].
pub struct NestedLoopJoinExec {
    /// Left input. `execute` requires it to expose exactly one output partition.
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right input. `execute(partition, ..)` reads the matching partition of it.
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Optional join filter; `None` means no filter is attached to the join.
    pub(crate) filter: Option<JoinFilter>,
    /// Kind of join to perform.
    pub(crate) join_type: JoinType,
    /// Join schema returned by `build_join_schema`, before `projection` is applied.
    join_schema: SchemaRef,
    /// Collected left input, requested through `try_once` in `execute`.
    build_side_data: OnceAsync<JoinLeftData>,
    /// Spilled copy of the left input. A handle is given to every stream inside
    /// `SpillState::Pending` and is only requested when a stream falls back to
    /// spilling.
    left_spill_data: Arc<OnceAsync<LeftSpillData>>,
    /// Column index information returned by `build_join_schema` together with
    /// `join_schema`.
    column_indices: Vec<ColumnIndex>,
    /// Optional projection; each entry indexes into `join_schema` / `column_indices`.
    projection: Option<ProjectionRef>,

    /// Metrics registry of this plan node.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed by `compute_properties` when the node is built.
    cache: Arc<PlanProperties>,
}
228
/// Builder that collects the inputs of a [`NestedLoopJoinExec`] and validates
/// them in [`NestedLoopJoinExecBuilder::build`].
pub struct NestedLoopJoinExecBuilder {
    /// Left input of the join.
    left: Arc<dyn ExecutionPlan>,
    /// Right input of the join.
    right: Arc<dyn ExecutionPlan>,
    /// Kind of join to build.
    join_type: JoinType,
    /// Optional join filter; defaults to `None`.
    filter: Option<JoinFilter>,
    /// Optional output projection; defaults to `None`.
    projection: Option<ProjectionRef>,
}
237
238impl NestedLoopJoinExecBuilder {
239 pub fn new(
241 left: Arc<dyn ExecutionPlan>,
242 right: Arc<dyn ExecutionPlan>,
243 join_type: JoinType,
244 ) -> Self {
245 Self {
246 left,
247 right,
248 join_type,
249 filter: None,
250 projection: None,
251 }
252 }
253
254 pub fn with_projection(self, projection: Option<Vec<usize>>) -> Self {
256 self.with_projection_ref(projection.map(Into::into))
257 }
258
259 pub fn with_projection_ref(mut self, projection: Option<ProjectionRef>) -> Self {
261 self.projection = projection;
262 self
263 }
264
265 pub fn with_filter(mut self, filter: Option<JoinFilter>) -> Self {
267 self.filter = filter;
268 self
269 }
270
271 pub fn build(self) -> Result<NestedLoopJoinExec> {
273 let Self {
274 left,
275 right,
276 join_type,
277 filter,
278 projection,
279 } = self;
280
281 let left_schema = left.schema();
282 let right_schema = right.schema();
283 check_join_is_valid(&left_schema, &right_schema, &[])?;
284 let (join_schema, column_indices) =
285 build_join_schema(&left_schema, &right_schema, &join_type);
286 let join_schema = Arc::new(join_schema);
287 let cache = NestedLoopJoinExec::compute_properties(
288 &left,
289 &right,
290 &join_schema,
291 join_type,
292 projection.as_deref(),
293 )?;
294 Ok(NestedLoopJoinExec {
295 left,
296 right,
297 filter,
298 join_type,
299 join_schema,
300 build_side_data: Default::default(),
301 left_spill_data: Arc::new(OnceAsync::default()),
302 column_indices,
303 projection,
304 metrics: Default::default(),
305 cache: Arc::new(cache),
306 })
307 }
308}
309
310impl From<&NestedLoopJoinExec> for NestedLoopJoinExecBuilder {
311 fn from(exec: &NestedLoopJoinExec) -> Self {
312 Self {
313 left: Arc::clone(exec.left()),
314 right: Arc::clone(exec.right()),
315 join_type: exec.join_type,
316 filter: exec.filter.clone(),
317 projection: exec.projection.clone(),
318 }
319 }
320}
321
322impl NestedLoopJoinExec {
323 pub fn try_new(
325 left: Arc<dyn ExecutionPlan>,
326 right: Arc<dyn ExecutionPlan>,
327 filter: Option<JoinFilter>,
328 join_type: &JoinType,
329 projection: Option<Vec<usize>>,
330 ) -> Result<Self> {
331 NestedLoopJoinExecBuilder::new(left, right, *join_type)
332 .with_projection(projection)
333 .with_filter(filter)
334 .build()
335 }
336
337 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
339 &self.left
340 }
341
342 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
344 &self.right
345 }
346
347 pub fn filter(&self) -> Option<&JoinFilter> {
349 self.filter.as_ref()
350 }
351
352 pub fn join_type(&self) -> &JoinType {
354 &self.join_type
355 }
356
357 pub fn projection(&self) -> &Option<ProjectionRef> {
358 &self.projection
359 }
360
361 fn compute_properties(
363 left: &Arc<dyn ExecutionPlan>,
364 right: &Arc<dyn ExecutionPlan>,
365 schema: &SchemaRef,
366 join_type: JoinType,
367 projection: Option<&[usize]>,
368 ) -> Result<PlanProperties> {
369 let mut eq_properties = join_equivalence_properties(
371 left.equivalence_properties().clone(),
372 right.equivalence_properties().clone(),
373 &join_type,
374 Arc::clone(schema),
375 &Self::maintains_input_order(join_type),
376 None,
377 &[],
379 )?;
380
381 let mut output_partitioning =
382 asymmetric_join_output_partitioning(left, right, &join_type)?;
383
384 let emission_type = if left.boundedness().is_unbounded() {
385 EmissionType::Final
386 } else if right.pipeline_behavior() == EmissionType::Incremental {
387 match join_type {
388 JoinType::Inner
391 | JoinType::LeftSemi
392 | JoinType::RightSemi
393 | JoinType::Right
394 | JoinType::RightAnti
395 | JoinType::RightMark => EmissionType::Incremental,
396 JoinType::Left
399 | JoinType::LeftAnti
400 | JoinType::LeftMark
401 | JoinType::Full => EmissionType::Both,
402 }
403 } else {
404 right.pipeline_behavior()
405 };
406
407 if let Some(projection) = projection {
408 let projection_mapping = ProjectionMapping::from_indices(projection, schema)?;
410 let out_schema = project_schema(schema, Some(&projection))?;
411 output_partitioning =
412 output_partitioning.project(&projection_mapping, &eq_properties);
413 eq_properties = eq_properties.project(&projection_mapping, out_schema);
414 }
415
416 Ok(PlanProperties::new(
417 eq_properties,
418 output_partitioning,
419 emission_type,
420 boundedness_from_children([left, right]),
421 ))
422 }
423
424 fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
426 vec![false, false]
427 }
428
429 pub fn contains_projection(&self) -> bool {
430 self.projection.is_some()
431 }
432
433 pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
434 let projection = projection.map(Into::into);
435 can_project(&self.schema(), projection.as_deref())?;
437 let projection =
438 combine_projections(projection.as_ref(), self.projection.as_ref())?;
439 NestedLoopJoinExecBuilder::from(self)
440 .with_projection_ref(projection)
441 .build()
442 }
443
444 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
453 let left = self.left();
454 let right = self.right();
455 let new_join = NestedLoopJoinExec::try_new(
456 Arc::clone(right),
457 Arc::clone(left),
458 self.filter().map(JoinFilter::swap),
459 &self.join_type().swap(),
460 swap_join_projection(
461 left.schema().fields().len(),
462 right.schema().fields().len(),
463 self.projection.as_deref(),
464 self.join_type(),
465 ),
466 )?;
467
468 let plan: Arc<dyn ExecutionPlan> = if matches!(
471 self.join_type(),
472 JoinType::LeftSemi
473 | JoinType::RightSemi
474 | JoinType::LeftAnti
475 | JoinType::RightAnti
476 | JoinType::LeftMark
477 | JoinType::RightMark
478 ) || self.projection.is_some()
479 {
480 Arc::new(new_join)
481 } else {
482 reorder_output_after_swap(
483 Arc::new(new_join),
484 &self.left().schema(),
485 &self.right().schema(),
486 )?
487 };
488
489 Ok(plan)
490 }
491
492 fn with_new_children_and_same_properties(
493 &self,
494 mut children: Vec<Arc<dyn ExecutionPlan>>,
495 ) -> Self {
496 let left = children.swap_remove(0);
497 let right = children.swap_remove(0);
498
499 Self {
500 left,
501 right,
502 metrics: ExecutionPlanMetricsSet::new(),
503 build_side_data: Default::default(),
504 left_spill_data: Arc::new(OnceAsync::default()),
505 cache: Arc::clone(&self.cache),
506 filter: self.filter.clone(),
507 join_type: self.join_type,
508 join_schema: Arc::clone(&self.join_schema),
509 column_indices: self.column_indices.clone(),
510 projection: self.projection.clone(),
511 }
512 }
513}
514
515impl DisplayAs for NestedLoopJoinExec {
516 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
517 match t {
518 DisplayFormatType::Default | DisplayFormatType::Verbose => {
519 let display_filter = self.filter.as_ref().map_or_else(
520 || "".to_string(),
521 |f| format!(", filter={}", f.expression()),
522 );
523 let display_projections = if self.contains_projection() {
524 format!(
525 ", projection=[{}]",
526 self.projection
527 .as_ref()
528 .unwrap()
529 .iter()
530 .map(|index| format!(
531 "{}@{}",
532 self.join_schema.fields().get(*index).unwrap().name(),
533 index
534 ))
535 .collect::<Vec<_>>()
536 .join(", ")
537 )
538 } else {
539 "".to_string()
540 };
541 write!(
542 f,
543 "NestedLoopJoinExec: join_type={:?}{}{}",
544 self.join_type, display_filter, display_projections
545 )
546 }
547 DisplayFormatType::TreeRender => {
548 if *self.join_type() != JoinType::Inner {
549 writeln!(f, "join_type={:?}", self.join_type)
550 } else {
551 Ok(())
552 }
553 }
554 }
555 }
556}
557
558impl ExecutionPlan for NestedLoopJoinExec {
559 fn name(&self) -> &'static str {
560 "NestedLoopJoinExec"
561 }
562
563 fn properties(&self) -> &Arc<PlanProperties> {
564 &self.cache
565 }
566
567 fn required_input_distribution(&self) -> Vec<Distribution> {
568 vec![
569 Distribution::SinglePartition,
570 Distribution::UnspecifiedDistribution,
571 ]
572 }
573
574 fn maintains_input_order(&self) -> Vec<bool> {
575 Self::maintains_input_order(self.join_type)
576 }
577
578 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
579 vec![&self.left, &self.right]
580 }
581
582 fn with_new_children(
583 self: Arc<Self>,
584 children: Vec<Arc<dyn ExecutionPlan>>,
585 ) -> Result<Arc<dyn ExecutionPlan>> {
586 check_if_same_properties!(self, children);
587 Ok(Arc::new(
588 NestedLoopJoinExecBuilder::new(
589 Arc::clone(&children[0]),
590 Arc::clone(&children[1]),
591 self.join_type,
592 )
593 .with_filter(self.filter.clone())
594 .with_projection_ref(self.projection.clone())
595 .build()?,
596 ))
597 }
598
599 fn execute(
600 &self,
601 partition: usize,
602 context: Arc<TaskContext>,
603 ) -> Result<SendableRecordBatchStream> {
604 assert_eq_or_internal_err!(
605 self.left.output_partitioning().partition_count(),
606 1,
607 "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
608 consider using CoalescePartitionsExec or the EnforceDistribution rule"
609 );
610
611 let metrics = NestedLoopJoinMetrics::new(&self.metrics, partition);
612 let batch_size = context.session_config().batch_size();
613
614 let column_indices_after_projection = match self.projection.as_ref() {
616 Some(projection) => projection
617 .iter()
618 .map(|i| self.column_indices[*i].clone())
619 .collect(),
620 None => self.column_indices.clone(),
621 };
622
623 let right_partition_count = self.right().output_partitioning().partition_count();
624
625 let load_reservation =
629 MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
630 .register(context.memory_pool());
631
632 let build_side_data = self.build_side_data.try_once(|| {
633 let stream = self.left.execute(0, Arc::clone(&context))?;
634
635 Ok(collect_left_input(
636 stream,
637 metrics.join_metrics.clone(),
638 load_reservation,
639 need_produce_result_in_final(self.join_type),
640 right_partition_count,
641 ))
642 })?;
643
644 let probe_side_data = self.right.execute(partition, Arc::clone(&context))?;
645
646 let full_join_multi_partition =
662 matches!(self.join_type, JoinType::Full) && right_partition_count > 1;
663 let spill_state = if context.runtime_env().disk_manager.tmp_files_enabled()
664 && !full_join_multi_partition
665 {
666 SpillState::Pending {
667 left_plan: Arc::clone(&self.left),
668 task_context: Arc::clone(&context),
669 left_spill_data: Arc::clone(&self.left_spill_data),
670 }
671 } else {
672 SpillState::Disabled
673 };
674
675 Ok(Box::pin(NestedLoopJoinStream::new(
676 self.schema(),
677 self.filter.clone(),
678 self.join_type,
679 probe_side_data,
680 build_side_data,
681 column_indices_after_projection,
682 metrics,
683 batch_size,
684 spill_state,
685 )))
686 }
687
688 fn metrics(&self) -> Option<MetricsSet> {
689 Some(self.metrics.clone_inner())
690 }
691
692 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
693 let join_columns = Vec::new();
701
702 let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(None)?);
707 let right_stats = Arc::unwrap_or_clone(match partition {
708 Some(partition) => self.right.partition_statistics(Some(partition))?,
709 None => self.right.partition_statistics(None)?,
710 });
711
712 let stats = estimate_join_statistics(
713 left_stats,
714 right_stats,
715 &join_columns,
716 &self.join_type,
717 &self.join_schema,
718 )?;
719
720 Ok(Arc::new(stats.project(self.projection.as_ref())))
721 }
722
723 fn try_swapping_with_projection(
727 &self,
728 projection: &ProjectionExec,
729 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
730 if self.contains_projection() {
732 return Ok(None);
733 }
734
735 let schema = self.schema();
736 if let Some(JoinData {
737 projected_left_child,
738 projected_right_child,
739 join_filter,
740 ..
741 }) = try_pushdown_through_join(
742 projection,
743 self.left(),
744 self.right(),
745 &[],
746 &schema,
747 self.filter(),
748 )? {
749 Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
750 Arc::new(projected_left_child),
751 Arc::new(projected_right_child),
752 join_filter,
753 self.join_type(),
754 None,
756 )?)))
757 } else {
758 try_embed_projection(projection, self)
759 }
760 }
761}
762
impl EmbeddedProjection for NestedLoopJoinExec {
    /// Trait entry point for embedding a projection into the join.
    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
        // Method-call syntax picks the inherent `NestedLoopJoinExec::with_projection`
        // (inherent methods take precedence over trait methods), so this
        // delegates rather than recursing.
        self.with_projection(projection)
    }
}
768
/// Buffered left-side input together with the bookkeeping attached to it.
pub(crate) struct JoinLeftData {
    /// Left-side rows concatenated into a single batch.
    batch: RecordBatch,
    /// Bitmap over the rows of `batch`. It has one (initially unset) bit per
    /// row when visited tracking was requested at construction, and is empty
    /// otherwise.
    bitmap: SharedBitmapBuilder,
    /// Counter decremented by `report_probe_completed`.
    probe_threads_counter: AtomicUsize,
    /// Memory reservation stored alongside the buffered data; never read
    /// directly (hence the `dead_code` expectation).
    #[expect(dead_code)]
    reservation: MemoryReservation,
}
783
784impl JoinLeftData {
785 pub(crate) fn new(
786 batch: RecordBatch,
787 bitmap: SharedBitmapBuilder,
788 probe_threads_counter: AtomicUsize,
789 reservation: MemoryReservation,
790 ) -> Self {
791 Self {
792 batch,
793 bitmap,
794 probe_threads_counter,
795 reservation,
796 }
797 }
798
799 pub(crate) fn batch(&self) -> &RecordBatch {
800 &self.batch
801 }
802
803 pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
804 &self.bitmap
805 }
806
807 pub(crate) fn report_probe_completed(&self) -> bool {
810 self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
811 }
812}
813
814async fn collect_left_input(
816 stream: SendableRecordBatchStream,
817 join_metrics: BuildProbeJoinMetrics,
818 reservation: MemoryReservation,
819 with_visited_left_side: bool,
820 probe_threads_count: usize,
821) -> Result<JoinLeftData> {
822 let schema = stream.schema();
823
824 let (batches, metrics, reservation) = stream
826 .try_fold(
827 (Vec::new(), join_metrics, reservation),
828 |(mut batches, metrics, reservation), batch| async {
829 let batch_size = batch.get_array_memory_size();
830 reservation.try_grow(batch_size)?;
832 metrics.build_mem_used.add(batch_size);
834 metrics.build_input_batches.add(1);
835 metrics.build_input_rows.add(batch.num_rows());
836 batches.push(batch);
838 Ok((batches, metrics, reservation))
839 },
840 )
841 .await?;
842
843 let merged_batch = concat_batches(&schema, &batches)?;
844
845 let visited_left_side = if with_visited_left_side {
847 let n_rows = merged_batch.num_rows();
848 let buffer_size = n_rows.div_ceil(8);
849 reservation.try_grow(buffer_size)?;
850 metrics.build_mem_used.add(buffer_size);
851
852 let mut buffer = BooleanBufferBuilder::new(n_rows);
853 buffer.append_n(n_rows, false);
854 buffer
855 } else {
856 BooleanBufferBuilder::new(0)
857 };
858
859 Ok(JoinLeftData::new(
860 merged_batch,
861 Mutex::new(visited_left_side),
862 AtomicUsize::new(probe_threads_count),
863 reservation,
864 ))
865}
866
/// States of the [`NestedLoopJoinStream`] state machine driven by `poll_next`.
#[derive(Debug, Clone, Copy)]
enum NLJState {
    /// Initial state: obtain the (next chunk of the) left input.
    BufferingLeft,
    /// Poll the right input for the next batch.
    FetchingRight,
    /// Join the current right batch against the buffered left data.
    ProbeRight,
    /// Handled by `handle_emit_right_unmatched`.
    // NOTE(review): presumably emits unmatched rows of the current right batch — confirm.
    EmitRightUnmatched,
    /// Entered once the right input is exhausted; handled by
    /// `handle_emit_left_unmatched`.
    EmitLeftUnmatched,
    /// Handled by `handle_emit_global_right_unmatched`.
    // NOTE(review): presumably only reached in the spill path — confirm.
    EmitGlobalRightUnmatched,
    /// Terminal state, handled by `handle_done`.
    Done,
}
/// Result of spilling the left input to disk during the spill fallback.
pub(crate) struct LeftSpillData {
    /// Spill manager that wrote the file; also used to read it back as a stream.
    spill_manager: SpillManager,
    /// Temporary file holding the spilled left input.
    spill_file: RefCountedTempFile,
    /// Schema of the left input stream that was spilled.
    schema: SchemaRef,
}
896
/// Spill-related mode of a [`NestedLoopJoinStream`].
pub(crate) enum SpillState {
    /// Spilling is not available; errors while buffering the left input are
    /// returned to the caller.
    Disabled,

    /// Spilling is available but not in use yet. Holds what is needed to
    /// re-execute and spill the left input if a fallback is required.
    Pending {
        /// Plan of the left input, executed again when the fallback starts.
        left_plan: Arc<dyn ExecutionPlan>,
        /// Task context used to execute the plan and to build spill managers
        /// and memory reservations.
        task_context: Arc<TaskContext>,
        /// Holder for the spilled left input (created by the exec node).
        left_spill_data: Arc<OnceAsync<LeftSpillData>>,
    },

    /// The stream runs in memory-limited mode; boxed state of that mode.
    Active(Box<SpillStateActive>),
}
925
/// State kept while a [`NestedLoopJoinStream`] runs in memory-limited mode.
pub(crate) struct SpillStateActive {
    /// Future resolving to the spilled left input.
    left_spill_fut: OnceFut<LeftSpillData>,
    /// Stream reading the spilled left input back; `None` until the spill
    /// future has resolved and again after the spilled input is exhausted.
    left_stream: Option<SendableRecordBatchStream>,
    /// Schema of the spilled left input, set together with `left_stream`.
    left_schema: Option<SchemaRef>,
    /// Reservation charged for the left batches buffered in this mode.
    reservation: MemoryReservation,
    /// Left batches read for the current chunk and not yet concatenated.
    pending_batches: Vec<RecordBatch>,
    /// Source of the right input; `open_pass` yields a stream for each pass.
    right_input: ReplayableStreamSource,
    /// Right-side bitmaps indexed by right batch position, combined with OR by
    /// `merge_current_right_bitmap`.
    global_right_bitmaps: Vec<BooleanBuffer>,
    /// Reservation charged for `global_right_bitmaps`.
    global_right_bitmaps_reservation: MemoryReservation,
    /// Index of the current right batch; reset to 0 when a new pass is opened.
    right_batch_index: usize,
}
955
956impl SpillStateActive {
957 fn merge_current_right_bitmap(&mut self, idx: usize, values: BooleanBuffer) {
967 if idx >= self.global_right_bitmaps.len() {
968 let bytes = values.len().div_ceil(8);
975 self.global_right_bitmaps_reservation.grow(bytes);
976 self.global_right_bitmaps.push(values);
977 } else {
978 self.global_right_bitmaps[idx] =
981 self.global_right_bitmaps[idx].bitor(&values);
982 }
983 }
984}
985
/// Stream that produces the join output for one partition of the right input.
///
/// It is a state machine over [`NLJState`], advanced by `poll_next`.
pub(crate) struct NestedLoopJoinStream {
    /// Schema reported by `RecordBatchStream::schema` and used for `output_buffer`.
    pub(crate) output_schema: Arc<Schema>,
    /// Optional join filter.
    pub(crate) join_filter: Option<JoinFilter>,
    /// Kind of join being executed.
    pub(crate) join_type: JoinType,
    /// Right-side input. It is taken out during the spill fallback and set
    /// again from the replayable right source when a pass is opened.
    pub(crate) right_data: Option<SendableRecordBatchStream>,
    /// Future resolving to the in-memory left input.
    pub(crate) left_data: OnceFut<JoinLeftData>,
    /// Column index table for the output columns, as passed by the exec node.
    pub(crate) column_indices: Vec<ColumnIndex>,
    /// Metrics for this partition.
    pub(crate) metrics: NestedLoopJoinMetrics,

    /// Batch size from the session configuration; also passed to `output_buffer`.
    batch_size: usize,

    /// Result of `need_produce_right_in_final(join_type)`; when set, a matched
    /// bitmap is allocated for every fetched right batch.
    should_track_unmatched_right: bool,

    /// Current state of the state machine.
    state: NLJState,
    /// Coalescer for output batches, created with `output_schema` and `batch_size`.
    output_buffer: Box<BatchCoalescer>,
    /// Starts as `false`.
    // NOTE(review): its use is outside this part of the file — confirm meaning.
    handled_empty_output: bool,

    /// Left data currently being probed; `None` until buffering has completed.
    buffered_left_data: Option<Arc<JoinLeftData>>,
    /// Reset to 0 whenever a new right batch is fetched.
    left_probe_idx: usize,
    /// Starts at 0.
    // NOTE(review): its use is outside this part of the file — confirm meaning.
    left_emit_idx: usize,
    /// `true` once no more left input remains to be buffered.
    left_exhausted: bool,
    /// Starts as `true`; cleared when memory-limited buffering has to stop
    /// before the spilled left input is exhausted.
    left_buffered_in_one_pass: bool,

    /// Right batch currently being processed.
    current_right_batch: Option<RecordBatch>,
    /// Per-row bitmap for `current_right_batch`, created all-unset when
    /// `should_track_unmatched_right` is set.
    current_right_batch_matched: Option<BooleanArray>,

    /// Spill mode of this stream.
    spill_state: SpillState,
}
1066
/// Metrics recorded by one [`NestedLoopJoinStream`].
pub(crate) struct NestedLoopJoinMetrics {
    /// Build/probe metrics shared with other join operators.
    pub(crate) join_metrics: BuildProbeJoinMetrics,
    /// Ratio metric registered under the name `"selectivity"`.
    pub(crate) selectivity: RatioMetrics,
    /// Metrics recorded by the spill managers of this stream.
    pub(crate) spill_metrics: SpillMetrics,
}
1075
1076impl NestedLoopJoinMetrics {
1077 pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
1078 Self {
1079 join_metrics: BuildProbeJoinMetrics::new(partition, metrics),
1080 selectivity: MetricBuilder::new(metrics)
1081 .with_type(MetricType::Summary)
1082 .ratio_metrics("selectivity", partition),
1083 spill_metrics: SpillMetrics::new(metrics, partition),
1084 }
1085 }
1086}
1087
1088impl Stream for NestedLoopJoinStream {
1089 type Item = Result<RecordBatch>;
1090
1091 fn poll_next(
1122 mut self: std::pin::Pin<&mut Self>,
1123 cx: &mut std::task::Context<'_>,
1124 ) -> Poll<Option<Self::Item>> {
1125 loop {
1126 match self.state {
1127 NLJState::BufferingLeft => {
1133 debug!("[NLJState] Entering: {:?}", self.state);
1134 let build_metric = self.metrics.join_metrics.build_time.clone();
1139 let _build_timer = build_metric.timer();
1140
1141 match self.handle_buffering_left(cx) {
1142 ControlFlow::Continue(()) => continue,
1143 ControlFlow::Break(poll) => return poll,
1144 }
1145 }
1146
1147 NLJState::FetchingRight => {
1170 debug!("[NLJState] Entering: {:?}", self.state);
1171 let join_metric = self.metrics.join_metrics.join_time.clone();
1173 let _join_timer = join_metric.timer();
1174
1175 match self.handle_fetching_right(cx) {
1176 ControlFlow::Continue(()) => continue,
1177 ControlFlow::Break(poll) => return poll,
1178 }
1179 }
1180
1181 NLJState::ProbeRight => {
1196 debug!("[NLJState] Entering: {:?}", self.state);
1197
1198 let join_metric = self.metrics.join_metrics.join_time.clone();
1200 let _join_timer = join_metric.timer();
1201
1202 match self.handle_probe_right() {
1203 ControlFlow::Continue(()) => continue,
1204 ControlFlow::Break(poll) => {
1205 return self.metrics.join_metrics.baseline.record_poll(poll);
1206 }
1207 }
1208 }
1209
1210 NLJState::EmitRightUnmatched => {
1217 debug!("[NLJState] Entering: {:?}", self.state);
1218
1219 let join_metric = self.metrics.join_metrics.join_time.clone();
1221 let _join_timer = join_metric.timer();
1222
1223 match self.handle_emit_right_unmatched() {
1224 ControlFlow::Continue(()) => continue,
1225 ControlFlow::Break(poll) => {
1226 return self.metrics.join_metrics.baseline.record_poll(poll);
1227 }
1228 }
1229 }
1230
1231 NLJState::EmitLeftUnmatched => {
1247 debug!("[NLJState] Entering: {:?}", self.state);
1248
1249 let join_metric = self.metrics.join_metrics.join_time.clone();
1251 let _join_timer = join_metric.timer();
1252
1253 match self.handle_emit_left_unmatched() {
1254 ControlFlow::Continue(()) => continue,
1255 ControlFlow::Break(poll) => {
1256 return self.metrics.join_metrics.baseline.record_poll(poll);
1257 }
1258 }
1259 }
1260
1261 NLJState::EmitGlobalRightUnmatched => {
1267 debug!("[NLJState] Entering: {:?}", self.state);
1268
1269 let join_metric = self.metrics.join_metrics.join_time.clone();
1270 let _join_timer = join_metric.timer();
1271
1272 match self.handle_emit_global_right_unmatched(cx) {
1273 ControlFlow::Continue(()) => continue,
1274 ControlFlow::Break(poll) => {
1275 return self.metrics.join_metrics.baseline.record_poll(poll);
1276 }
1277 }
1278 }
1279
1280 NLJState::Done => {
1282 debug!("[NLJState] Entering: {:?}", self.state);
1283
1284 let join_metric = self.metrics.join_metrics.join_time.clone();
1286 let _join_timer = join_metric.timer();
1287 let poll = self.handle_done();
1291 return self.metrics.join_metrics.baseline.record_poll(poll);
1292 }
1293 }
1294 }
1295 }
1296}
1297
1298impl RecordBatchStream for NestedLoopJoinStream {
1299 fn schema(&self) -> SchemaRef {
1300 Arc::clone(&self.output_schema)
1301 }
1302}
1303
1304impl NestedLoopJoinStream {
1305 #[expect(clippy::too_many_arguments)]
1306 pub(crate) fn new(
1307 schema: Arc<Schema>,
1308 filter: Option<JoinFilter>,
1309 join_type: JoinType,
1310 right_data: SendableRecordBatchStream,
1311 left_data: OnceFut<JoinLeftData>,
1312 column_indices: Vec<ColumnIndex>,
1313 metrics: NestedLoopJoinMetrics,
1314 batch_size: usize,
1315 spill_state: SpillState,
1316 ) -> Self {
1317 Self {
1318 output_schema: Arc::clone(&schema),
1319 join_filter: filter,
1320 join_type,
1321 right_data: Some(right_data),
1322 column_indices,
1323 left_data,
1324 metrics,
1325 buffered_left_data: None,
1326 output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
1327 batch_size,
1328 current_right_batch: None,
1329 current_right_batch_matched: None,
1330 state: NLJState::BufferingLeft,
1331 left_probe_idx: 0,
1332 left_emit_idx: 0,
1333 left_exhausted: false,
1334 left_buffered_in_one_pass: true,
1335 handled_empty_output: false,
1336 should_track_unmatched_right: need_produce_right_in_final(join_type),
1337 spill_state,
1338 }
1339 }
1340
1341 fn is_memory_limited(&self) -> bool {
1343 matches!(self.spill_state, SpillState::Active(_))
1344 }
1345
1346 fn can_fallback_to_spill(&self, error: &datafusion_common::DataFusionError) -> bool {
1348 matches!(self.spill_state, SpillState::Pending { .. })
1349 && matches!(
1350 error.find_root(),
1351 datafusion_common::DataFusionError::ResourcesExhausted(_)
1352 )
1353 }
1354
1355 fn initiate_fallback(&mut self) -> Result<()> {
1361 let (left_plan, context, left_spill_data) =
1363 match std::mem::replace(&mut self.spill_state, SpillState::Disabled) {
1364 SpillState::Pending {
1365 left_plan,
1366 task_context,
1367 left_spill_data,
1368 } => (left_plan, task_context, left_spill_data),
1369 _ => {
1370 return internal_err!(
1371 "initiate_fallback called in non-Pending spill state"
1372 );
1373 }
1374 };
1375
1376 let left_spill_fut = left_spill_data.try_once(|| {
1380 let plan = Arc::clone(&left_plan);
1381 let ctx = Arc::clone(&context);
1382 let spill_metrics = self.metrics.spill_metrics.clone();
1383 Ok(async move {
1384 let mut stream = plan.execute(0, Arc::clone(&ctx))?;
1385 let schema = stream.schema();
1386 let left_spill_manager = SpillManager::new(
1387 ctx.runtime_env(),
1388 spill_metrics,
1389 Arc::clone(&schema),
1390 )
1391 .with_compression_type(ctx.session_config().spill_compression());
1392
1393 let result = left_spill_manager
1394 .spill_record_batch_stream_and_return_max_batch_memory(
1395 &mut stream,
1396 "NestedLoopJoin left spill",
1397 )
1398 .await?;
1399
1400 match result {
1401 Some((file, _max_batch_memory)) => Ok(LeftSpillData {
1402 spill_manager: left_spill_manager,
1403 spill_file: file,
1404 schema,
1405 }),
1406 None => {
1407 internal_err!("Left side produced no data to spill")
1408 }
1409 }
1410 })
1411 })?;
1412
1413 let reservation = MemoryConsumer::new("NestedLoopJoinLoad[fallback]".to_string())
1415 .with_can_spill(true)
1416 .register(context.memory_pool());
1417
1418 let global_right_bitmaps_reservation =
1422 MemoryConsumer::new("NestedLoopJoinGlobalRightBitmaps".to_string())
1423 .register(context.memory_pool());
1424
1425 let right_schema = self
1427 .right_data
1428 .as_ref()
1429 .expect("right_data must be present before fallback")
1430 .schema();
1431 let right_data = self
1432 .right_data
1433 .take()
1434 .expect("right_data must be present before fallback");
1435 let right_spill_manager = SpillManager::new(
1436 context.runtime_env(),
1437 self.metrics.spill_metrics.clone(),
1438 right_schema,
1439 )
1440 .with_compression_type(context.session_config().spill_compression());
1441
1442 self.spill_state = SpillState::Active(Box::new(SpillStateActive {
1443 left_spill_fut,
1444 left_stream: None,
1445 left_schema: None,
1446 reservation,
1447 pending_batches: Vec::new(),
1448 right_input: ReplayableStreamSource::new(
1449 right_data,
1450 right_spill_manager,
1451 "NestedLoopJoin right spill",
1452 ),
1453 global_right_bitmaps: Vec::new(),
1454 global_right_bitmaps_reservation,
1455 right_batch_index: 0,
1456 }));
1457
1458 self.state = NLJState::BufferingLeft;
1461
1462 Ok(())
1463 }
1464
1465 fn handle_buffering_left(
1473 &mut self,
1474 cx: &mut std::task::Context<'_>,
1475 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1476 if self.is_memory_limited() {
1477 self.handle_buffering_left_memory_limited(cx)
1478 } else {
1479 match self.left_data.get_shared(cx) {
1481 Poll::Ready(Ok(left_data)) => {
1482 self.buffered_left_data = Some(left_data);
1483 self.left_exhausted = true;
1484 self.state = NLJState::FetchingRight;
1485 ControlFlow::Continue(())
1486 }
1487 Poll::Ready(Err(e)) => {
1488 if self.can_fallback_to_spill(&e) {
1489 debug!(
1490 "NestedLoopJoin: OnceFut failed with OOM, \
1491 falling back to memory-limited mode"
1492 );
1493 match self.initiate_fallback() {
1494 Ok(()) => ControlFlow::Continue(()),
1495 Err(fallback_err) => {
1496 ControlFlow::Break(Poll::Ready(Some(Err(fallback_err))))
1497 }
1498 }
1499 } else {
1500 ControlFlow::Break(Poll::Ready(Some(Err(e))))
1501 }
1502 }
1503 Poll::Pending => ControlFlow::Break(Poll::Pending),
1504 }
1505 }
1506 }
1507
/// Memory-limited variant of the `BufferingLeft` handler.
///
/// Reads the spilled left input back in chunks: batches are buffered until
/// the reservation can no longer grow or the spilled input ends, then
/// concatenated into a new `JoinLeftData`. A pass over the right input is
/// opened and the state moves to `FetchingRight`. When no left batch is
/// available at all, the state moves to `Done`.
///
/// Returns `Break(Poll::Pending)` while the spill future or the spilled
/// stream is not ready, and `Break` with an error when reading fails.
fn handle_buffering_left_memory_limited(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
    let SpillState::Active(active) = &mut self.spill_state else {
        unreachable!(
            "handle_buffering_left_memory_limited called without Active spill state"
        );
    };

    // No reader yet: wait for the left spill to complete, then open the
    // spill file as a stream.
    if active.left_stream.is_none() {
        match active.left_spill_fut.get_shared(cx) {
            Poll::Ready(Ok(spill_data)) => {
                match spill_data
                    .spill_manager
                    .read_spill_as_stream(spill_data.spill_file.clone(), None)
                {
                    Ok(stream) => {
                        active.left_schema = Some(Arc::clone(&spill_data.schema));
                        active.left_stream = Some(stream);
                    }
                    Err(e) => {
                        return ControlFlow::Break(Poll::Ready(Some(Err(e))));
                    }
                }
            }
            Poll::Ready(Err(e)) => {
                return ControlFlow::Break(Poll::Ready(Some(Err(e))));
            }
            Poll::Pending => {
                return ControlFlow::Break(Poll::Pending);
            }
        }
    }

    let left_stream = active
        .left_stream
        .as_mut()
        .expect("left_stream must be set after spill future resolves");

    // Buffer left batches until memory runs out or the spilled input ends.
    // Batches buffered before a `Pending` stay in `pending_batches` for the
    // next poll.
    loop {
        match left_stream.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(batch))) => {
                // Empty batches contribute nothing.
                if batch.num_rows() == 0 {
                    continue;
                }
                let batch_rows = batch.num_rows();
                let batch_size = batch.get_array_memory_size();
                let can_grow = active.reservation.try_grow(batch_size).is_ok();

                if !can_grow && !active.pending_batches.is_empty() {
                    // The reservation is full and this chunk already has data:
                    // keep the batch that was just read and close the chunk.
                    // NOTE(review): this batch is added without growing the
                    // reservation or updating the build metrics — confirm
                    // that this is intended.
                    active.pending_batches.push(batch);
                    self.left_exhausted = false;
                    self.left_buffered_in_one_pass = false;
                    break;
                } else if !can_grow {
                    // Nothing buffered yet: account for the batch without a
                    // limit check so that the chunk holds at least one batch.
                    active.reservation.grow(batch_size);
                }

                self.metrics.join_metrics.build_mem_used.add(batch_size);
                self.metrics.join_metrics.build_input_batches.add(1);
                self.metrics.join_metrics.build_input_rows.add(batch_rows);
                active.pending_batches.push(batch);
            }
            Poll::Ready(Some(Err(e))) => {
                return ControlFlow::Break(Poll::Ready(Some(Err(e))));
            }
            Poll::Ready(None) => {
                // The spilled left input has been read completely.
                self.left_exhausted = true;
                break;
            }
            Poll::Pending => {
                return ControlFlow::Break(Poll::Pending);
            }
        }
    }

    // The reader is no longer needed once the spilled input is exhausted.
    if self.left_exhausted {
        active.left_stream = None;
    }

    // No left rows for this chunk: there is nothing left to join.
    if active.pending_batches.is_empty() {
        self.left_exhausted = true;
        self.state = NLJState::Done;
        return ControlFlow::Continue(());
    }

    let merged_batch = match concat_batches(
        active
            .left_schema
            .as_ref()
            .expect("left_schema must be set"),
        &active.pending_batches,
    ) {
        Ok(batch) => batch,
        Err(e) => {
            return ControlFlow::Break(Poll::Ready(Some(Err(e.into()))));
        }
    };
    active.pending_batches.clear();

    // Allocate the visited bitmap (one bit per left row) only for join types
    // that need it; it is charged to the chunk reservation.
    let with_visited = need_produce_result_in_final(self.join_type);
    let n_rows = merged_batch.num_rows();
    let visited_left_side = if with_visited {
        let buffer_size = n_rows.div_ceil(8);
        active.reservation.grow(buffer_size);
        self.metrics.join_metrics.build_mem_used.add(buffer_size);
        let mut buffer = BooleanBufferBuilder::new(n_rows);
        buffer.append_n(n_rows, false);
        buffer
    } else {
        BooleanBufferBuilder::new(0)
    };

    // `JoinLeftData` requires a reservation; the real accounting stays in
    // `active.reservation`.
    // NOTE(review): `new_empty` presumably yields a zero-sized reservation — confirm.
    let dummy_reservation = active.reservation.new_empty();

    let left_data = JoinLeftData::new(
        merged_batch,
        Mutex::new(visited_left_side),
        // The probe counter is fixed to 1 in this mode.
        AtomicUsize::new(1),
        dummy_reservation,
    );

    self.buffered_left_data = Some(Arc::new(left_data));

    // Start a new pass over the right input for this left chunk.
    active.right_batch_index = 0;
    match active.right_input.open_pass() {
        Ok(stream) => {
            self.right_data = Some(stream);
        }
        Err(e) => {
            return ControlFlow::Break(Poll::Ready(Some(Err(e))));
        }
    }

    self.state = NLJState::FetchingRight;
    ControlFlow::Continue(())
}
1671
1672 fn handle_fetching_right(
1677 &mut self,
1678 cx: &mut std::task::Context<'_>,
1679 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1680 match self
1681 .right_data
1682 .as_mut()
1683 .expect("right_data must be present while fetching right")
1684 .poll_next_unpin(cx)
1685 {
1686 Poll::Ready(result) => match result {
1687 Some(Ok(right_batch)) => {
1688 let right_batch_rows = right_batch.num_rows();
1690 self.metrics.join_metrics.input_rows.add(right_batch_rows);
1691 self.metrics.join_metrics.input_batches.add(1);
1692
1693 if right_batch_rows == 0 {
1695 return ControlFlow::Continue(());
1696 }
1697
1698 self.current_right_batch = Some(right_batch);
1699
1700 if self.should_track_unmatched_right {
1702 let zeroed_buf = BooleanBuffer::new_unset(right_batch_rows);
1703 self.current_right_batch_matched =
1704 Some(BooleanArray::new(zeroed_buf, None));
1705 }
1706
1707 self.left_probe_idx = 0;
1708 self.state = NLJState::ProbeRight;
1709 ControlFlow::Continue(())
1710 }
1711 Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1712 None => {
1713 self.state = NLJState::EmitLeftUnmatched;
1714 ControlFlow::Continue(())
1715 }
1716 },
1717 Poll::Pending => ControlFlow::Break(Poll::Pending),
1718 }
1719 }
1720
1721 fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1723 if let Some(poll) = self.maybe_flush_ready_batch() {
1725 return ControlFlow::Break(poll);
1726 }
1727
1728 match self.process_probe_batch() {
1730 Ok(true) => ControlFlow::Continue(()),
1734 Ok(false) => {
1738 self.left_probe_idx = 0;
1740
1741 if let (Ok(left_data), Some(right_batch)) =
1744 (self.get_left_data(), self.current_right_batch.as_ref())
1745 {
1746 let left_rows = left_data.batch().num_rows();
1747 let right_rows = right_batch.num_rows();
1748 self.metrics.selectivity.add_total(left_rows * right_rows);
1749 }
1750
1751 if self.should_track_unmatched_right {
1752 debug_assert!(
1753 self.current_right_batch_matched.is_some(),
1754 "If it's required to track matched rows in the right input, the right bitmap must be present"
1755 );
1756 self.state = NLJState::EmitRightUnmatched;
1757 } else {
1758 self.current_right_batch = None;
1759 self.state = NLJState::FetchingRight;
1760 }
1761 ControlFlow::Continue(())
1762 }
1763 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1764 }
1765 }
1766
1767 fn handle_emit_right_unmatched(
1774 &mut self,
1775 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1776 if self.is_memory_limited() {
1778 debug_assert!(
1779 self.current_right_batch_matched.is_some(),
1780 "right bitmap must be present"
1781 );
1782 let bitmap = std::mem::take(&mut self.current_right_batch_matched)
1783 .expect("right bitmap should be available");
1784 let (values, _nulls) = bitmap.into_parts();
1785
1786 if let SpillState::Active(ref mut active) = self.spill_state {
1787 let idx = active.right_batch_index;
1788 active.merge_current_right_bitmap(idx, values);
1789 active.right_batch_index += 1;
1790 }
1791
1792 self.current_right_batch = None;
1793 self.state = NLJState::FetchingRight;
1794 return ControlFlow::Continue(());
1795 }
1796
1797 if let Some(poll) = self.maybe_flush_ready_batch() {
1800 return ControlFlow::Break(poll);
1801 }
1802
1803 debug_assert!(
1804 self.current_right_batch_matched.is_some()
1805 && self.current_right_batch.is_some(),
1806 "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
1807 );
1808 match self.process_right_unmatched() {
1809 Ok(Some(batch)) => match self.output_buffer.push_batch(batch) {
1810 Ok(()) => {
1811 debug_assert!(self.current_right_batch.is_none());
1812 self.state = NLJState::FetchingRight;
1813 ControlFlow::Continue(())
1814 }
1815 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1816 },
1817 Ok(None) => {
1818 debug_assert!(self.current_right_batch.is_none());
1819 self.state = NLJState::FetchingRight;
1820 ControlFlow::Continue(())
1821 }
1822 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1823 }
1824 }
1825
1826 fn handle_emit_left_unmatched(
1832 &mut self,
1833 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1834 if let Some(poll) = self.maybe_flush_ready_batch() {
1836 return ControlFlow::Break(poll);
1837 }
1838
1839 match self.process_left_unmatched() {
1841 Ok(true) => ControlFlow::Continue(()),
1844 Ok(false) => match self.output_buffer.finish_buffered_batch() {
1846 Ok(()) => {
1847 if let Some(poll) = self.maybe_flush_ready_batch() {
1852 return ControlFlow::Break(poll);
1853 }
1854
1855 if !self.left_exhausted && self.is_memory_limited() {
1856 if let SpillState::Active(ref active) = self.spill_state {
1859 active.reservation.resize(0);
1860 }
1861 self.buffered_left_data = None;
1862 self.left_probe_idx = 0;
1863 self.left_emit_idx = 0;
1864 self.state = NLJState::BufferingLeft;
1865 } else if self.is_memory_limited()
1866 && self.should_track_unmatched_right
1867 {
1868 self.right_data = None;
1875 self.state = NLJState::EmitGlobalRightUnmatched;
1876 } else {
1877 self.state = NLJState::Done;
1878 }
1879 ControlFlow::Continue(())
1880 }
1881 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1882 },
1883 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1884 }
1885 }
1886
1887 fn handle_emit_global_right_unmatched(
1892 &mut self,
1893 cx: &mut std::task::Context<'_>,
1894 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1895 if let Some(poll) = self.maybe_flush_ready_batch() {
1897 return ControlFlow::Break(poll);
1898 }
1899
1900 if self.right_data.is_none() {
1902 let SpillState::Active(ref mut active) = self.spill_state else {
1903 unreachable!("EmitGlobalRightUnmatched without Active spill state");
1904 };
1905 active.right_batch_index = 0;
1906 match active.right_input.open_pass() {
1907 Ok(stream) => {
1908 self.right_data = Some(stream);
1909 }
1910 Err(e) => {
1911 return ControlFlow::Break(Poll::Ready(Some(Err(e))));
1912 }
1913 }
1914 }
1915
1916 match self
1918 .right_data
1919 .as_mut()
1920 .expect("right_data must be present")
1921 .poll_next_unpin(cx)
1922 {
1923 Poll::Ready(Some(Ok(right_batch))) => {
1924 if right_batch.num_rows() == 0 {
1925 return ControlFlow::Continue(());
1926 }
1927
1928 let SpillState::Active(ref mut active) = self.spill_state else {
1929 unreachable!();
1930 };
1931 let idx = active.right_batch_index;
1932 active.right_batch_index += 1;
1933
1934 let bitmap = if idx < active.global_right_bitmaps.len() {
1936 BooleanArray::new(active.global_right_bitmaps[idx].clone(), None)
1937 } else {
1938 BooleanArray::new(
1940 BooleanBuffer::new_unset(right_batch.num_rows()),
1941 None,
1942 )
1943 };
1944
1945 let left_schema = Arc::clone(
1946 active
1947 .left_schema
1948 .as_ref()
1949 .expect("left_schema must be set"),
1950 );
1951
1952 match build_unmatched_batch(
1953 &self.output_schema,
1954 &right_batch,
1955 bitmap,
1956 &left_schema,
1957 &self.column_indices,
1958 self.join_type,
1959 JoinSide::Right,
1960 ) {
1961 Ok(Some(batch)) => match self.output_buffer.push_batch(batch) {
1962 Ok(()) => ControlFlow::Continue(()),
1963 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1964 },
1965 Ok(None) => ControlFlow::Continue(()),
1966 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1967 }
1968 }
1969 Poll::Ready(Some(Err(e))) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1970 Poll::Ready(None) => {
1971 match self.output_buffer.finish_buffered_batch() {
1973 Ok(()) => {
1974 self.state = NLJState::Done;
1975 ControlFlow::Continue(())
1976 }
1977 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1978 }
1979 }
1980 Poll::Pending => ControlFlow::Break(Poll::Pending),
1981 }
1982 }
1983
1984 fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
1986 if let Some(poll) = self.maybe_flush_ready_batch() {
1988 return poll;
1989 }
1990
1991 if !self.handled_empty_output {
1997 let zero_count = Count::new();
1998 if *self.metrics.join_metrics.baseline.output_rows() == zero_count {
1999 let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
2000 self.handled_empty_output = true;
2001 return Poll::Ready(Some(Ok(empty_batch)));
2002 }
2003 }
2004
2005 Poll::Ready(None)
2006 }
2007
2008 fn process_probe_batch(&mut self) -> Result<bool> {
2015 let left_data = Arc::clone(self.get_left_data()?);
2016 let right_batch = self
2017 .current_right_batch
2018 .as_ref()
2019 .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
2020 .clone();
2021
2022 if self.left_probe_idx >= left_data.batch().num_rows() {
2024 return Ok(false);
2025 }
2026
2027 debug_assert_ne!(
2040 right_batch.num_rows(),
2041 0,
2042 "When fetching the right batch, empty batches will be skipped"
2043 );
2044
2045 let l_row_cnt_ratio = self.batch_size / right_batch.num_rows();
2046 if l_row_cnt_ratio > 10 {
2047 let l_row_count = std::cmp::min(
2051 l_row_cnt_ratio,
2052 left_data.batch().num_rows() - self.left_probe_idx,
2053 );
2054
2055 debug_assert!(
2056 l_row_count != 0,
2057 "This function should only be entered when there are remaining left rows to process"
2058 );
2059 let joined_batch = self.process_left_range_join(
2060 &left_data,
2061 &right_batch,
2062 self.left_probe_idx,
2063 l_row_count,
2064 )?;
2065
2066 if let Some(batch) = joined_batch {
2067 self.output_buffer.push_batch(batch)?;
2068 }
2069
2070 self.left_probe_idx += l_row_count;
2071
2072 return Ok(true);
2073 }
2074
2075 let l_idx = self.left_probe_idx;
2076 let joined_batch =
2077 self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
2078
2079 if let Some(batch) = joined_batch {
2080 self.output_buffer.push_batch(batch)?;
2081 }
2082
2083 self.left_probe_idx += 1;
2087
2088 Ok(true)
2090 }
2091
2092 fn process_left_range_join(
2098 &mut self,
2099 left_data: &JoinLeftData,
2100 right_batch: &RecordBatch,
2101 l_start_index: usize,
2102 l_row_count: usize,
2103 ) -> Result<Option<RecordBatch>> {
2104 let right_rows = right_batch.num_rows();
2110 let total_rows = l_row_count * right_rows;
2111
2112 let left_indices: UInt32Array =
2114 UInt32Array::from_iter_values((0..l_row_count).flat_map(|i| {
2115 std::iter::repeat_n((l_start_index + i) as u32, right_rows)
2116 }));
2117 let right_indices: UInt32Array = UInt32Array::from_iter_values(
2118 (0..l_row_count).flat_map(|_| 0..right_rows as u32),
2119 );
2120
2121 debug_assert!(
2122 left_indices.len() == right_indices.len()
2123 && right_indices.len() == total_rows,
2124 "The length or cartesian product should be (left_size * right_size)",
2125 );
2126
2127 let bitmap_combined = if let Some(filter) = &self.join_filter {
2130 let intermediate_batch = if filter.schema.fields().is_empty() {
2132 create_record_batch_with_empty_schema(
2134 Arc::new((*filter.schema).clone()),
2135 total_rows,
2136 )?
2137 } else {
2138 let mut filter_columns: Vec<Arc<dyn Array>> =
2139 Vec::with_capacity(filter.column_indices().len());
2140 for column_index in filter.column_indices() {
2141 let array = if column_index.side == JoinSide::Left {
2142 let col = left_data.batch().column(column_index.index);
2143 take(col.as_ref(), &left_indices, None)?
2144 } else {
2145 let col = right_batch.column(column_index.index);
2146 take(col.as_ref(), &right_indices, None)?
2147 };
2148 filter_columns.push(array);
2149 }
2150
2151 RecordBatch::try_new(Arc::new((*filter.schema).clone()), filter_columns)?
2152 };
2153
2154 let filter_result = filter
2155 .expression()
2156 .evaluate(&intermediate_batch)?
2157 .into_array(intermediate_batch.num_rows())?;
2158 let filter_arr = as_boolean_array(&filter_result)?;
2159
2160 boolean_mask_from_filter(filter_arr)
2162 } else {
2163 BooleanArray::from(vec![true; total_rows])
2165 };
2166
2167 let mut left_bitmap = if need_produce_result_in_final(self.join_type) {
2172 Some(left_data.bitmap().lock())
2173 } else {
2174 None
2175 };
2176
2177 let mut local_right_bitmap = if self.should_track_unmatched_right {
2181 let mut current_right_batch_bitmap = BooleanBufferBuilder::new(right_rows);
2182 current_right_batch_bitmap.append_n(right_rows, false);
2184 Some(current_right_batch_bitmap)
2185 } else {
2186 None
2187 };
2188
2189 for (i, is_matched) in bitmap_combined.iter().enumerate() {
2191 let is_matched = is_matched.ok_or_else(|| {
2192 internal_datafusion_err!("Must be Some after the previous combining step")
2193 })?;
2194
2195 let l_index = l_start_index + i / right_rows;
2196 let r_index = i % right_rows;
2197
2198 if let Some(bitmap) = left_bitmap.as_mut()
2199 && is_matched
2200 {
2201 bitmap.set_bit(l_index, true);
2203 }
2204
2205 if let Some(bitmap) = local_right_bitmap.as_mut()
2206 && is_matched
2207 {
2208 bitmap.set_bit(r_index, true);
2209 }
2210 }
2211
2212 if self.should_track_unmatched_right {
2214 let global_right_bitmap =
2216 std::mem::take(&mut self.current_right_batch_matched).ok_or_else(
2217 || internal_datafusion_err!("right batch's bitmap should be present"),
2218 )?;
2219 let (buf, nulls) = global_right_bitmap.into_parts();
2220 debug_assert!(nulls.is_none());
2221
2222 let current_right_bitmap = local_right_bitmap
2223 .ok_or_else(|| {
2224 internal_datafusion_err!(
2225 "Should be Some if the current join type requires right bitmap"
2226 )
2227 })?
2228 .finish();
2229 let updated_global_right_bitmap = buf.bitor(¤t_right_bitmap);
2230
2231 self.current_right_batch_matched =
2232 Some(BooleanArray::new(updated_global_right_bitmap, None));
2233 }
2234
2235 if matches!(
2237 self.join_type,
2238 JoinType::LeftAnti
2239 | JoinType::LeftSemi
2240 | JoinType::LeftMark
2241 | JoinType::RightAnti
2242 | JoinType::RightMark
2243 | JoinType::RightSemi
2244 ) {
2245 return Ok(None);
2246 }
2247
2248 if self.output_schema.fields().is_empty() {
2251 let row_count = bitmap_combined.true_count();
2253 return Ok(Some(create_record_batch_with_empty_schema(
2254 Arc::clone(&self.output_schema),
2255 row_count,
2256 )?));
2257 }
2258
2259 let mut out_columns: Vec<Arc<dyn Array>> =
2260 Vec::with_capacity(self.output_schema.fields().len());
2261 for column_index in &self.column_indices {
2262 let array = if column_index.side == JoinSide::Left {
2263 let col = left_data.batch().column(column_index.index);
2264 take(col.as_ref(), &left_indices, None)?
2265 } else {
2266 let col = right_batch.column(column_index.index);
2267 take(col.as_ref(), &right_indices, None)?
2268 };
2269 out_columns.push(array);
2270 }
2271 let pre_filtered =
2272 RecordBatch::try_new(Arc::clone(&self.output_schema), out_columns)?;
2273 let filtered = filter_record_batch(&pre_filtered, &bitmap_combined)?;
2274 Ok(Some(filtered))
2275 }
2276
2277 fn process_single_left_row_join(
2283 &mut self,
2284 left_data: &JoinLeftData,
2285 right_batch: &RecordBatch,
2286 l_index: usize,
2287 ) -> Result<Option<RecordBatch>> {
2288 let right_row_count = right_batch.num_rows();
2289 if right_row_count == 0 {
2290 return Ok(None);
2291 }
2292
2293 let cur_right_bitmap = if let Some(filter) = &self.join_filter {
2294 apply_filter_to_row_join_batch(
2295 left_data.batch(),
2296 l_index,
2297 right_batch,
2298 filter,
2299 )?
2300 } else {
2301 BooleanArray::from(vec![true; right_row_count])
2302 };
2303
2304 self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
2305
2306 if matches!(
2309 self.join_type,
2310 JoinType::LeftAnti
2311 | JoinType::LeftSemi
2312 | JoinType::LeftMark
2313 | JoinType::RightAnti
2314 | JoinType::RightMark
2315 | JoinType::RightSemi
2316 ) {
2317 return Ok(None);
2318 }
2319
2320 if !cur_right_bitmap.has_true() {
2321 Ok(None)
2323 } else {
2324 let join_batch = build_row_join_batch(
2326 &self.output_schema,
2327 left_data.batch(),
2328 l_index,
2329 right_batch,
2330 Some(cur_right_bitmap),
2331 &self.column_indices,
2332 JoinSide::Left,
2333 )?;
2334 Ok(join_batch)
2335 }
2336 }
2337
2338 fn process_left_unmatched(&mut self) -> Result<bool> {
2342 let left_data = self.get_left_data()?;
2343 let left_batch = left_data.batch();
2344
2345 let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
2351 let handled_by_other_partition =
2353 self.left_emit_idx == 0 && !left_data.report_probe_completed();
2354 let finished = self.left_emit_idx >= left_batch.num_rows();
2356
2357 if join_type_no_produce_left || handled_by_other_partition || finished {
2358 return Ok(false);
2359 }
2360
2361 let start_idx = self.left_emit_idx;
2366 let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
2367
2368 if let Some(batch) =
2369 self.process_left_unmatched_range(left_data, start_idx, end_idx)?
2370 {
2371 self.output_buffer.push_batch(batch)?;
2372 }
2373
2374 self.left_emit_idx = end_idx;
2376
2377 Ok(true)
2379 }
2380
2381 fn process_left_unmatched_range(
2394 &self,
2395 left_data: &JoinLeftData,
2396 start_idx: usize,
2397 end_idx: usize,
2398 ) -> Result<Option<RecordBatch>> {
2399 if start_idx == end_idx {
2400 return Ok(None);
2401 }
2402
2403 let left_batch = left_data.batch();
2406 let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx);
2407
2408 let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx);
2410 bitmap_sliced.append_n(end_idx - start_idx, false);
2411 let bitmap = left_data.bitmap().lock();
2412 for i in start_idx..end_idx {
2413 assert!(
2414 i - start_idx < bitmap_sliced.capacity(),
2415 "DBG: {start_idx}, {end_idx}"
2416 );
2417 bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
2418 }
2419 let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
2420
2421 let right_schema = self
2422 .right_data
2423 .as_ref()
2424 .expect("right_data must be present when building unmatched batch")
2425 .schema();
2426 build_unmatched_batch(
2427 &self.output_schema,
2428 &left_batch_sliced,
2429 bitmap_sliced,
2430 &right_schema,
2431 &self.column_indices,
2432 self.join_type,
2433 JoinSide::Left,
2434 )
2435 }
2436
2437 fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
2440 let right_batch_bitmap: BooleanArray =
2442 std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
2443 internal_datafusion_err!("right bitmap should be available")
2444 })?;
2445
2446 let right_batch = self.current_right_batch.take();
2447 let cur_right_batch = unwrap_or_internal_err!(right_batch);
2448
2449 let left_data = self.get_left_data()?;
2450 let left_schema = left_data.batch().schema();
2451
2452 let res = build_unmatched_batch(
2453 &self.output_schema,
2454 &cur_right_batch,
2455 right_batch_bitmap,
2456 &left_schema,
2457 &self.column_indices,
2458 self.join_type,
2459 JoinSide::Right,
2460 );
2461
2462 self.current_right_batch_matched = None;
2464
2465 res
2466 }
2467
2468 fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
2472 self.buffered_left_data
2473 .as_ref()
2474 .ok_or_else(|| internal_datafusion_err!("LeftData should be available"))
2475 }
2476
2477 fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
2480 if self.output_buffer.has_completed_batch()
2481 && let Some(batch) = self.output_buffer.next_completed_batch()
2482 {
2483 let output_rows = batch.num_rows();
2485 self.metrics.selectivity.add_part(output_rows);
2486
2487 return Some(Poll::Ready(Some(Ok(batch))));
2488 }
2489
2490 None
2491 }
2492
2493 fn update_matched_bitmap(
2509 &mut self,
2510 l_index: usize,
2511 r_matched_bitmap: &BooleanArray,
2512 ) -> Result<()> {
2513 let left_data = self.get_left_data()?;
2514
2515 if need_produce_result_in_final(self.join_type) && r_matched_bitmap.has_true() {
2517 let mut bitmap = left_data.bitmap().lock();
2518 bitmap.set_bit(l_index, true);
2519 }
2520
2521 if self.should_track_unmatched_right {
2523 debug_assert!(self.current_right_batch_matched.is_some());
2524 let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
2526 .ok_or_else(|| {
2527 internal_datafusion_err!("right batch's bitmap should be present")
2528 })?;
2529 let (buf, nulls) = right_bitmap.into_parts();
2530 debug_assert!(nulls.is_none());
2531 let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
2532
2533 self.current_right_batch_matched =
2534 Some(BooleanArray::new(updated_right_bitmap, None));
2535 }
2536
2537 Ok(())
2538 }
2539}
2540
2541fn apply_filter_to_row_join_batch(
2547 left_batch: &RecordBatch,
2548 l_index: usize,
2549 right_batch: &RecordBatch,
2550 filter: &JoinFilter,
2551) -> Result<BooleanArray> {
2552 debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
2553
2554 let intermediate_batch = if filter.schema.fields().is_empty() {
2555 create_record_batch_with_empty_schema(
2558 Arc::new((*filter.schema).clone()),
2559 right_batch.num_rows(),
2560 )?
2561 } else {
2562 build_row_join_batch(
2563 &filter.schema,
2564 left_batch,
2565 l_index,
2566 right_batch,
2567 None,
2568 &filter.column_indices,
2569 JoinSide::Left,
2570 )?
2571 .ok_or_else(|| internal_datafusion_err!("This function assume input batch is not empty, so the intermediate batch can't be empty too"))?
2572 };
2573
2574 let filter_result = filter
2575 .expression()
2576 .evaluate(&intermediate_batch)?
2577 .into_array(intermediate_batch.num_rows())?;
2578 let filter_arr = as_boolean_array(&filter_result)?;
2579
2580 let bitmap_combined = boolean_mask_from_filter(filter_arr);
2582
2583 Ok(bitmap_combined)
2584}
2585
2586#[inline]
2592fn boolean_mask_from_filter(filter_arr: &BooleanArray) -> BooleanArray {
2593 let (values, nulls) = filter_arr.clone().into_parts();
2594 match nulls {
2595 Some(nulls) => BooleanArray::new(nulls.inner() & &values, None),
2596 None => BooleanArray::new(values, None),
2597 }
2598}
2599
2600fn build_row_join_batch(
2648 output_schema: &Schema,
2649 build_side_batch: &RecordBatch,
2650 build_side_index: usize,
2651 probe_side_batch: &RecordBatch,
2652 probe_side_filter: Option<BooleanArray>,
2653 col_indices: &[ColumnIndex],
2655 build_side: JoinSide,
2658) -> Result<Option<RecordBatch>> {
2659 debug_assert!(build_side != JoinSide::None);
2660
2661 let filtered_probe_batch = if let Some(filter) = probe_side_filter {
2664 &filter_record_batch(probe_side_batch, &filter)?
2665 } else {
2666 probe_side_batch
2667 };
2668
2669 if filtered_probe_batch.num_rows() == 0 {
2670 return Ok(None);
2671 }
2672
2673 if output_schema.fields.is_empty() {
2681 return Ok(Some(create_record_batch_with_empty_schema(
2682 Arc::new(output_schema.clone()),
2683 filtered_probe_batch.num_rows(),
2684 )?));
2685 }
2686
2687 let mut columns: Vec<Arc<dyn Array>> =
2688 Vec::with_capacity(output_schema.fields().len());
2689
2690 for column_index in col_indices {
2691 let array = if column_index.side == build_side {
2692 let original_left_array = build_side_batch.column(column_index.index);
2695
2696 match original_left_array.data_type() {
2702 DataType::List(field) | DataType::LargeList(field)
2703 if field.data_type() == &DataType::Utf8View =>
2704 {
2705 let indices_iter = std::iter::repeat_n(
2706 build_side_index as u64,
2707 filtered_probe_batch.num_rows(),
2708 );
2709 let indices_array = UInt64Array::from_iter_values(indices_iter);
2710 take(original_left_array.as_ref(), &indices_array, None)?
2711 }
2712 _ => {
2713 let scalar_value = ScalarValue::try_from_array(
2714 original_left_array.as_ref(),
2715 build_side_index,
2716 )?;
2717 scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
2718 }
2719 }
2720 } else {
2721 Arc::clone(filtered_probe_batch.column(column_index.index))
2723 };
2724
2725 columns.push(array);
2726 }
2727
2728 Ok(Some(RecordBatch::try_new(
2729 Arc::new(output_schema.clone()),
2730 columns,
2731 )?))
2732}
2733
2734fn build_unmatched_batch_empty_schema(
2741 output_schema: &SchemaRef,
2742 batch_bitmap: &BooleanArray,
2743 join_type: JoinType,
2745) -> Result<Option<RecordBatch>> {
2746 let result_size = match join_type {
2747 JoinType::Left
2748 | JoinType::Right
2749 | JoinType::Full
2750 | JoinType::LeftAnti
2751 | JoinType::RightAnti => batch_bitmap.false_count(),
2752 JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
2753 JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
2754 _ => unreachable!(),
2755 };
2756
2757 if output_schema.fields().is_empty() {
2758 Ok(Some(create_record_batch_with_empty_schema(
2759 Arc::clone(output_schema),
2760 result_size,
2761 )?))
2762 } else {
2763 Ok(None)
2764 }
2765}
2766
2767fn create_record_batch_with_empty_schema(
2771 schema: SchemaRef,
2772 row_count: usize,
2773) -> Result<RecordBatch> {
2774 let options = RecordBatchOptions::new()
2775 .with_match_field_names(true)
2776 .with_row_count(Some(row_count));
2777
2778 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| {
2779 internal_datafusion_err!("Failed to create empty record batch: {}", e)
2780 })
2781}
2782
2783fn build_unmatched_batch(
2819 output_schema: &SchemaRef,
2820 batch: &RecordBatch,
2821 batch_bitmap: BooleanArray,
2822 another_side_schema: &SchemaRef,
2824 col_indices: &[ColumnIndex],
2825 join_type: JoinType,
2826 batch_side: JoinSide,
2827) -> Result<Option<RecordBatch>> {
2828 debug_assert_ne!(join_type, JoinType::Inner);
2830 debug_assert_ne!(batch_side, JoinSide::None);
2831
2832 if let Some(batch) =
2834 build_unmatched_batch_empty_schema(output_schema, &batch_bitmap, join_type)?
2835 {
2836 return Ok(Some(batch));
2837 }
2838
2839 match join_type {
2840 JoinType::Full | JoinType::Right | JoinType::Left => {
2841 if join_type == JoinType::Right {
2842 debug_assert_eq!(batch_side, JoinSide::Right);
2843 }
2844 if join_type == JoinType::Left {
2845 debug_assert_eq!(batch_side, JoinSide::Left);
2846 }
2847
2848 let flipped_bitmap = not(&batch_bitmap)?;
2851
2852 let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
2854 .fields()
2855 .iter()
2856 .map(|field| new_null_array(field.data_type(), 1))
2857 .collect();
2858
2859 let nullable_left_schema = Arc::new(Schema::new(
2863 another_side_schema
2864 .fields()
2865 .iter()
2866 .map(|field| (**field).clone().with_nullable(true))
2867 .collect::<Vec<_>>(),
2868 ));
2869 let left_null_batch = if nullable_left_schema.fields.is_empty() {
2870 create_record_batch_with_empty_schema(nullable_left_schema, 0)?
2873 } else {
2874 RecordBatch::try_new(nullable_left_schema, left_null_columns)?
2875 };
2876
2877 debug_assert_ne!(batch_side, JoinSide::None);
2878 let opposite_side = batch_side.negate();
2879
2880 build_row_join_batch(
2881 output_schema,
2882 &left_null_batch,
2883 0,
2884 batch,
2885 Some(flipped_bitmap),
2886 col_indices,
2887 opposite_side,
2888 )
2889 }
2890 JoinType::RightSemi
2891 | JoinType::RightAnti
2892 | JoinType::LeftSemi
2893 | JoinType::LeftAnti => {
2894 if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
2895 debug_assert_eq!(batch_side, JoinSide::Right);
2896 }
2897 if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
2898 debug_assert_eq!(batch_side, JoinSide::Left);
2899 }
2900
2901 let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi)
2902 {
2903 batch_bitmap.clone()
2904 } else {
2905 not(&batch_bitmap)?
2906 };
2907
2908 if !bitmap.has_true() {
2909 return Ok(None);
2910 }
2911
2912 let mut columns: Vec<Arc<dyn Array>> =
2913 Vec::with_capacity(output_schema.fields().len());
2914
2915 for column_index in col_indices {
2916 debug_assert!(column_index.side == batch_side);
2917
2918 let col = batch.column(column_index.index);
2919 let filtered_col = filter(col, &bitmap)?;
2920
2921 columns.push(filtered_col);
2922 }
2923
2924 Ok(Some(RecordBatch::try_new(
2925 Arc::clone(output_schema),
2926 columns,
2927 )?))
2928 }
2929 JoinType::RightMark | JoinType::LeftMark => {
2930 if join_type == JoinType::RightMark {
2931 debug_assert_eq!(batch_side, JoinSide::Right);
2932 }
2933 if join_type == JoinType::LeftMark {
2934 debug_assert_eq!(batch_side, JoinSide::Left);
2935 }
2936
2937 let mut columns: Vec<Arc<dyn Array>> =
2938 Vec::with_capacity(output_schema.fields().len());
2939
2940 let mut right_batch_bitmap_opt = Some(batch_bitmap);
2942
2943 for column_index in col_indices {
2944 if column_index.side == batch_side {
2945 let col = batch.column(column_index.index);
2946
2947 columns.push(Arc::clone(col));
2948 } else if column_index.side == JoinSide::None {
2949 let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
2950 match right_batch_bitmap {
2951 Some(right_batch_bitmap) => {
2952 columns.push(Arc::new(right_batch_bitmap))
2953 }
2954 None => unreachable!("Should only be one mark column"),
2955 }
2956 } else {
2957 return internal_err!(
2958 "Not possible to have this join side for RightMark join"
2959 );
2960 }
2961 }
2962
2963 Ok(Some(RecordBatch::try_new(
2964 Arc::clone(output_schema),
2965 columns,
2966 )?))
2967 }
2968 _ => internal_err!(
2969 "If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"
2970 ),
2971 }
2972}
2973
2974#[cfg(test)]
2975pub(crate) mod tests {
2976 use super::*;
2977 use crate::test::{TestMemoryExec, assert_join_metrics};
2978 use crate::{
2979 common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
2980 };
2981
2982 use arrow::compute::SortOptions;
2983 use arrow::datatypes::{DataType, Field};
2984 use datafusion_common::assert_contains;
2985 use datafusion_common::test_util::batches_to_sort_string;
2986 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2987 use datafusion_expr::Operator;
2988 use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
2989 use datafusion_physical_expr::{Partitioning, PhysicalExpr};
2990 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
2991
2992 use insta::allow_duplicates;
2993 use insta::assert_snapshot;
2994 use rstest::rstest;
2995
2996 fn build_table(
2997 a: (&str, &Vec<i32>),
2998 b: (&str, &Vec<i32>),
2999 c: (&str, &Vec<i32>),
3000 batch_size: Option<usize>,
3001 sorted_column_names: Vec<&str>,
3002 ) -> Arc<dyn ExecutionPlan> {
3003 let batch = build_table_i32(a, b, c);
3004 let schema = batch.schema();
3005
3006 let batches = if let Some(batch_size) = batch_size {
3007 let num_batches = batch.num_rows().div_ceil(batch_size);
3008 (0..num_batches)
3009 .map(|i| {
3010 let start = i * batch_size;
3011 let remaining_rows = batch.num_rows() - start;
3012 batch.slice(start, batch_size.min(remaining_rows))
3013 })
3014 .collect::<Vec<_>>()
3015 } else {
3016 vec![batch]
3017 };
3018
3019 let mut sort_info = vec![];
3020 for name in sorted_column_names {
3021 let index = schema.index_of(name).unwrap();
3022 let sort_expr = PhysicalSortExpr::new(
3023 Arc::new(Column::new(name, index)),
3024 SortOptions::new(false, false),
3025 );
3026 sort_info.push(sort_expr);
3027 }
3028 let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
3029 if let Some(ordering) = LexOrdering::new(sort_info) {
3030 source = source.try_with_sort_information(vec![ordering]).unwrap();
3031 }
3032
3033 let source = Arc::new(source);
3034 Arc::new(TestMemoryExec::update_cache(&source))
3035 }
3036
3037 fn build_left_table() -> Arc<dyn ExecutionPlan> {
3038 build_table(
3039 ("a1", &vec![5, 9, 11]),
3040 ("b1", &vec![5, 8, 8]),
3041 ("c1", &vec![50, 90, 110]),
3042 None,
3043 Vec::new(),
3044 )
3045 }
3046
3047 fn build_right_table() -> Arc<dyn ExecutionPlan> {
3048 build_table(
3049 ("a2", &vec![12, 2, 10]),
3050 ("b2", &vec![10, 2, 10]),
3051 ("c2", &vec![40, 80, 100]),
3052 None,
3053 Vec::new(),
3054 )
3055 }
3056
3057 fn prepare_join_filter() -> JoinFilter {
3058 let column_indices = vec![
3059 ColumnIndex {
3060 index: 1,
3061 side: JoinSide::Left,
3062 },
3063 ColumnIndex {
3064 index: 1,
3065 side: JoinSide::Right,
3066 },
3067 ];
3068 let intermediate_schema = Schema::new(vec![
3069 Field::new("x", DataType::Int32, true),
3070 Field::new("x", DataType::Int32, true),
3071 ]);
3072 let left_filter = Arc::new(BinaryExpr::new(
3074 Arc::new(Column::new("x", 0)),
3075 Operator::NotEq,
3076 Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
3077 )) as Arc<dyn PhysicalExpr>;
3078 let right_filter = Arc::new(BinaryExpr::new(
3080 Arc::new(Column::new("x", 1)),
3081 Operator::NotEq,
3082 Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
3083 )) as Arc<dyn PhysicalExpr>;
3084 let filter_expression =
3095 Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
3096 as Arc<dyn PhysicalExpr>;
3097
3098 JoinFilter::new(
3099 filter_expression,
3100 column_indices,
3101 Arc::new(intermediate_schema),
3102 )
3103 }
3104
3105 pub(crate) async fn multi_partitioned_join_collect(
3106 left: Arc<dyn ExecutionPlan>,
3107 right: Arc<dyn ExecutionPlan>,
3108 join_type: &JoinType,
3109 join_filter: Option<JoinFilter>,
3110 context: Arc<TaskContext>,
3111 ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
3112 let partition_count = 4;
3113
3114 let right = Arc::new(RepartitionExec::try_new(
3116 right,
3117 Partitioning::RoundRobinBatch(partition_count),
3118 )?) as Arc<dyn ExecutionPlan>;
3119
3120 let nested_loop_join =
3122 NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
3123 let columns = columns(&nested_loop_join.schema());
3124 let mut batches = vec![];
3125 for i in 0..partition_count {
3126 let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
3127 let more_batches = common::collect(stream).await?;
3128 batches.extend(
3129 more_batches
3130 .into_iter()
3131 .inspect(|b| {
3132 assert!(b.num_rows() <= context.session_config().batch_size())
3133 })
3134 .filter(|b| b.num_rows() > 0)
3135 .collect::<Vec<_>>(),
3136 );
3137 }
3138
3139 let metrics = nested_loop_join.metrics().unwrap();
3140
3141 Ok((columns, batches, metrics))
3142 }
3143
3144 fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
3145 let base = TaskContext::default();
3146 let cfg = base.session_config().clone().with_batch_size(batch_size);
3148 Arc::new(base.with_session_config(cfg))
3149 }
3150
3151 #[rstest]
3152 #[tokio::test]
3153 async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
3154 let task_ctx = new_task_ctx(batch_size);
3155 dbg!(&batch_size);
3156 let left = build_left_table();
3157 let right = build_right_table();
3158 let filter = prepare_join_filter();
3159 let (columns, batches, metrics) = multi_partitioned_join_collect(
3160 left,
3161 right,
3162 &JoinType::Inner,
3163 Some(filter),
3164 task_ctx,
3165 )
3166 .await?;
3167
3168 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3169 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3170 +----+----+----+----+----+----+
3171 | a1 | b1 | c1 | a2 | b2 | c2 |
3172 +----+----+----+----+----+----+
3173 | 5 | 5 | 50 | 2 | 2 | 80 |
3174 +----+----+----+----+----+----+
3175 "));
3176
3177 assert_join_metrics!(metrics, 1);
3178
3179 Ok(())
3180 }
3181
3182 #[rstest]
3183 #[tokio::test]
3184 async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
3185 let task_ctx = new_task_ctx(batch_size);
3186 let left = build_left_table();
3187 let right = build_right_table();
3188
3189 let filter = prepare_join_filter();
3190 let (columns, batches, metrics) = multi_partitioned_join_collect(
3191 left,
3192 right,
3193 &JoinType::Left,
3194 Some(filter),
3195 task_ctx,
3196 )
3197 .await?;
3198 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3199 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3200 +----+----+-----+----+----+----+
3201 | a1 | b1 | c1 | a2 | b2 | c2 |
3202 +----+----+-----+----+----+----+
3203 | 11 | 8 | 110 | | | |
3204 | 5 | 5 | 50 | 2 | 2 | 80 |
3205 | 9 | 8 | 90 | | | |
3206 +----+----+-----+----+----+----+
3207 "));
3208
3209 assert_join_metrics!(metrics, 3);
3210
3211 Ok(())
3212 }
3213
3214 #[rstest]
3215 #[tokio::test]
3216 async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
3217 let task_ctx = new_task_ctx(batch_size);
3218 let left = build_left_table();
3219 let right = build_right_table();
3220
3221 let filter = prepare_join_filter();
3222 let (columns, batches, metrics) = multi_partitioned_join_collect(
3223 left,
3224 right,
3225 &JoinType::Right,
3226 Some(filter),
3227 task_ctx,
3228 )
3229 .await?;
3230 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3231 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3232 +----+----+----+----+----+-----+
3233 | a1 | b1 | c1 | a2 | b2 | c2 |
3234 +----+----+----+----+----+-----+
3235 | | | | 10 | 10 | 100 |
3236 | | | | 12 | 10 | 40 |
3237 | 5 | 5 | 50 | 2 | 2 | 80 |
3238 +----+----+----+----+----+-----+
3239 "));
3240
3241 assert_join_metrics!(metrics, 3);
3242
3243 Ok(())
3244 }
3245
3246 #[rstest]
3247 #[tokio::test]
3248 async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
3249 let task_ctx = new_task_ctx(batch_size);
3250 let left = build_left_table();
3251 let right = build_right_table();
3252
3253 let filter = prepare_join_filter();
3254 let (columns, batches, metrics) = multi_partitioned_join_collect(
3255 left,
3256 right,
3257 &JoinType::Full,
3258 Some(filter),
3259 task_ctx,
3260 )
3261 .await?;
3262 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3263 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3264 +----+----+-----+----+----+-----+
3265 | a1 | b1 | c1 | a2 | b2 | c2 |
3266 +----+----+-----+----+----+-----+
3267 | | | | 10 | 10 | 100 |
3268 | | | | 12 | 10 | 40 |
3269 | 11 | 8 | 110 | | | |
3270 | 5 | 5 | 50 | 2 | 2 | 80 |
3271 | 9 | 8 | 90 | | | |
3272 +----+----+-----+----+----+-----+
3273 "));
3274
3275 assert_join_metrics!(metrics, 5);
3276
3277 Ok(())
3278 }
3279
3280 #[rstest]
3281 #[tokio::test]
3282 async fn join_left_semi_with_filter(
3283 #[values(1, 2, 16)] batch_size: usize,
3284 ) -> Result<()> {
3285 let task_ctx = new_task_ctx(batch_size);
3286 let left = build_left_table();
3287 let right = build_right_table();
3288
3289 let filter = prepare_join_filter();
3290 let (columns, batches, metrics) = multi_partitioned_join_collect(
3291 left,
3292 right,
3293 &JoinType::LeftSemi,
3294 Some(filter),
3295 task_ctx,
3296 )
3297 .await?;
3298 assert_eq!(columns, vec!["a1", "b1", "c1"]);
3299 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3300 +----+----+----+
3301 | a1 | b1 | c1 |
3302 +----+----+----+
3303 | 5 | 5 | 50 |
3304 +----+----+----+
3305 "));
3306
3307 assert_join_metrics!(metrics, 1);
3308
3309 Ok(())
3310 }
3311
3312 #[rstest]
3313 #[tokio::test]
3314 async fn join_left_anti_with_filter(
3315 #[values(1, 2, 16)] batch_size: usize,
3316 ) -> Result<()> {
3317 let task_ctx = new_task_ctx(batch_size);
3318 let left = build_left_table();
3319 let right = build_right_table();
3320
3321 let filter = prepare_join_filter();
3322 let (columns, batches, metrics) = multi_partitioned_join_collect(
3323 left,
3324 right,
3325 &JoinType::LeftAnti,
3326 Some(filter),
3327 task_ctx,
3328 )
3329 .await?;
3330 assert_eq!(columns, vec!["a1", "b1", "c1"]);
3331 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3332 +----+----+-----+
3333 | a1 | b1 | c1 |
3334 +----+----+-----+
3335 | 11 | 8 | 110 |
3336 | 9 | 8 | 90 |
3337 +----+----+-----+
3338 "));
3339
3340 assert_join_metrics!(metrics, 2);
3341
3342 Ok(())
3343 }
3344
3345 #[tokio::test]
3346 async fn join_has_correct_stats() -> Result<()> {
3347 let left = build_left_table();
3348 let right = build_right_table();
3349 let nested_loop_join = NestedLoopJoinExec::try_new(
3350 left,
3351 right,
3352 None,
3353 &JoinType::Left,
3354 Some(vec![1, 2]),
3355 )?;
3356 let stats = nested_loop_join.partition_statistics(None)?;
3357 assert_eq!(
3358 nested_loop_join.schema().fields().len(),
3359 stats.column_statistics.len(),
3360 );
3361 assert_eq!(2, stats.column_statistics.len());
3362 Ok(())
3363 }
3364
3365 #[rstest]
3366 #[tokio::test]
3367 async fn join_right_semi_with_filter(
3368 #[values(1, 2, 16)] batch_size: usize,
3369 ) -> Result<()> {
3370 let task_ctx = new_task_ctx(batch_size);
3371 let left = build_left_table();
3372 let right = build_right_table();
3373
3374 let filter = prepare_join_filter();
3375 let (columns, batches, metrics) = multi_partitioned_join_collect(
3376 left,
3377 right,
3378 &JoinType::RightSemi,
3379 Some(filter),
3380 task_ctx,
3381 )
3382 .await?;
3383 assert_eq!(columns, vec!["a2", "b2", "c2"]);
3384 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3385 +----+----+----+
3386 | a2 | b2 | c2 |
3387 +----+----+----+
3388 | 2 | 2 | 80 |
3389 +----+----+----+
3390 "));
3391
3392 assert_join_metrics!(metrics, 1);
3393
3394 Ok(())
3395 }
3396
3397 #[rstest]
3398 #[tokio::test]
3399 async fn join_right_anti_with_filter(
3400 #[values(1, 2, 16)] batch_size: usize,
3401 ) -> Result<()> {
3402 let task_ctx = new_task_ctx(batch_size);
3403 let left = build_left_table();
3404 let right = build_right_table();
3405
3406 let filter = prepare_join_filter();
3407 let (columns, batches, metrics) = multi_partitioned_join_collect(
3408 left,
3409 right,
3410 &JoinType::RightAnti,
3411 Some(filter),
3412 task_ctx,
3413 )
3414 .await?;
3415 assert_eq!(columns, vec!["a2", "b2", "c2"]);
3416 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3417 +----+----+-----+
3418 | a2 | b2 | c2 |
3419 +----+----+-----+
3420 | 10 | 10 | 100 |
3421 | 12 | 10 | 40 |
3422 +----+----+-----+
3423 "));
3424
3425 assert_join_metrics!(metrics, 2);
3426
3427 Ok(())
3428 }
3429
3430 #[rstest]
3431 #[tokio::test]
3432 async fn join_left_mark_with_filter(
3433 #[values(1, 2, 16)] batch_size: usize,
3434 ) -> Result<()> {
3435 let task_ctx = new_task_ctx(batch_size);
3436 let left = build_left_table();
3437 let right = build_right_table();
3438
3439 let filter = prepare_join_filter();
3440 let (columns, batches, metrics) = multi_partitioned_join_collect(
3441 left,
3442 right,
3443 &JoinType::LeftMark,
3444 Some(filter),
3445 task_ctx,
3446 )
3447 .await?;
3448 assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
3449 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3450 +----+----+-----+-------+
3451 | a1 | b1 | c1 | mark |
3452 +----+----+-----+-------+
3453 | 11 | 8 | 110 | false |
3454 | 5 | 5 | 50 | true |
3455 | 9 | 8 | 90 | false |
3456 +----+----+-----+-------+
3457 "));
3458
3459 assert_join_metrics!(metrics, 3);
3460
3461 Ok(())
3462 }
3463
3464 #[rstest]
3465 #[tokio::test]
3466 async fn join_right_mark_with_filter(
3467 #[values(1, 2, 16)] batch_size: usize,
3468 ) -> Result<()> {
3469 let task_ctx = new_task_ctx(batch_size);
3470 let left = build_left_table();
3471 let right = build_right_table();
3472
3473 let filter = prepare_join_filter();
3474 let (columns, batches, metrics) = multi_partitioned_join_collect(
3475 left,
3476 right,
3477 &JoinType::RightMark,
3478 Some(filter),
3479 task_ctx,
3480 )
3481 .await?;
3482 assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
3483
3484 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3485 +----+----+-----+-------+
3486 | a2 | b2 | c2 | mark |
3487 +----+----+-----+-------+
3488 | 10 | 10 | 100 | false |
3489 | 12 | 10 | 40 | false |
3490 | 2 | 2 | 80 | true |
3491 +----+----+-----+-------+
3492 "));
3493
3494 assert_join_metrics!(metrics, 3);
3495
3496 Ok(())
3497 }
3498
3499 #[tokio::test]
3500 async fn test_overallocation() -> Result<()> {
3501 let left = build_table(
3502 ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
3503 ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
3504 ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
3505 None,
3506 Vec::new(),
3507 );
3508 let right = build_table(
3509 ("a2", &vec![10, 11]),
3510 ("b2", &vec![12, 13]),
3511 ("c2", &vec![14, 15]),
3512 None,
3513 Vec::new(),
3514 );
3515 let filter = prepare_join_filter();
3516
3517 let fallback_join_types = vec![
3520 JoinType::Inner,
3521 JoinType::Left,
3522 JoinType::LeftSemi,
3523 JoinType::LeftAnti,
3524 JoinType::LeftMark,
3525 JoinType::Right,
3526 JoinType::RightSemi,
3527 JoinType::RightAnti,
3528 JoinType::RightMark,
3529 ];
3530
3531 for join_type in &fallback_join_types {
3532 let runtime = RuntimeEnvBuilder::new()
3533 .with_memory_limit(100, 1.0)
3534 .build_arc()?;
3535 let task_ctx = TaskContext::default().with_runtime(runtime);
3536 let task_ctx = Arc::new(task_ctx);
3537
3538 let _result = multi_partitioned_join_collect(
3540 Arc::clone(&left),
3541 Arc::clone(&right),
3542 join_type,
3543 Some(filter.clone()),
3544 task_ctx,
3545 )
3546 .await?;
3547 }
3548
3549 let runtime = RuntimeEnvBuilder::new()
3553 .with_memory_limit(100, 1.0)
3554 .build_arc()?;
3555 let task_ctx = TaskContext::default().with_runtime(runtime);
3556 let task_ctx = Arc::new(task_ctx);
3557 let err = multi_partitioned_join_collect(
3558 Arc::clone(&left),
3559 Arc::clone(&right),
3560 &JoinType::Full,
3561 Some(filter.clone()),
3562 task_ctx,
3563 )
3564 .await
3565 .unwrap_err();
3566 assert_contains!(err.to_string(), "Resources exhausted");
3567
3568 Ok(())
3569 }
3570
3571 fn columns(schema: &Schema) -> Vec<String> {
3573 schema.fields().iter().map(|f| f.name().clone()).collect()
3574 }
3575
3576 async fn join_collect(
3582 left: Arc<dyn ExecutionPlan>,
3583 right: Arc<dyn ExecutionPlan>,
3584 join_type: &JoinType,
3585 join_filter: Option<JoinFilter>,
3586 context: Arc<TaskContext>,
3587 ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
3588 let nested_loop_join =
3589 NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
3590 let columns = columns(&nested_loop_join.schema());
3591 let stream = nested_loop_join.execute(0, context)?;
3592 let batches: Vec<RecordBatch> = common::collect(stream)
3593 .await?
3594 .into_iter()
3595 .filter(|b| b.num_rows() > 0)
3596 .collect();
3597 let metrics = nested_loop_join.metrics().unwrap();
3598 Ok((columns, batches, metrics))
3599 }
3600
3601 fn task_ctx_with_memory_limit(
3603 memory_limit: usize,
3604 batch_size: usize,
3605 ) -> Result<Arc<TaskContext>> {
3606 let runtime = RuntimeEnvBuilder::new()
3607 .with_memory_limit(memory_limit, 1.0)
3608 .build_arc()?;
3609 let cfg = TaskContext::default()
3610 .session_config()
3611 .clone()
3612 .with_batch_size(batch_size);
3613 let task_ctx = TaskContext::default()
3614 .with_runtime(runtime)
3615 .with_session_config(cfg);
3616 Ok(Arc::new(task_ctx))
3617 }
3618
3619 #[tokio::test]
3620 async fn test_nlj_memory_limited_inner_join() -> Result<()> {
3621 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3623 let left = build_left_table();
3624 let right = build_right_table();
3625 let filter = prepare_join_filter();
3626
3627 let (columns, batches, metrics) =
3628 join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx).await?;
3629
3630 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3631
3632 assert!(
3634 metrics.spill_count().unwrap_or(0) > 0,
3635 "Expected spilling to occur under tight memory limit"
3636 );
3637
3638 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3640 +----+----+----+----+----+----+
3641 | a1 | b1 | c1 | a2 | b2 | c2 |
3642 +----+----+----+----+----+----+
3643 | 5 | 5 | 50 | 2 | 2 | 80 |
3644 +----+----+----+----+----+----+
3645 "));
3646 Ok(())
3647 }
3648
3649 #[tokio::test]
3650 async fn test_nlj_memory_limited_left_join() -> Result<()> {
3651 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3652 let left = build_left_table();
3653 let right = build_right_table();
3654 let filter = prepare_join_filter();
3655
3656 let (columns, batches, metrics) =
3657 join_collect(left, right, &JoinType::Left, Some(filter), task_ctx).await?;
3658
3659 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3660
3661 assert!(
3663 metrics.spill_count().unwrap_or(0) > 0,
3664 "Expected spilling to occur under tight memory limit"
3665 );
3666
3667 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3668 +----+----+-----+----+----+----+
3669 | a1 | b1 | c1 | a2 | b2 | c2 |
3670 +----+----+-----+----+----+----+
3671 | 11 | 8 | 110 | | | |
3672 | 5 | 5 | 50 | 2 | 2 | 80 |
3673 | 9 | 8 | 90 | | | |
3674 +----+----+-----+----+----+----+
3675 "));
3676 Ok(())
3677 }
3678
3679 #[tokio::test]
3680 async fn test_nlj_fits_in_memory_no_spill() -> Result<()> {
3681 let task_ctx = task_ctx_with_memory_limit(10_000_000, 16)?;
3683 let left = build_left_table();
3684 let right = build_right_table();
3685 let filter = prepare_join_filter();
3686
3687 let (columns, batches, metrics) =
3688 join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx).await?;
3689
3690 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3691
3692 assert_eq!(
3694 metrics.spill_count().unwrap_or(0),
3695 0,
3696 "Expected no spilling with generous memory limit"
3697 );
3698
3699 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3700 +----+----+----+----+----+----+
3701 | a1 | b1 | c1 | a2 | b2 | c2 |
3702 +----+----+----+----+----+----+
3703 | 5 | 5 | 50 | 2 | 2 | 80 |
3704 +----+----+----+----+----+----+
3705 "));
3706 Ok(())
3707 }
3708
3709 #[tokio::test]
3710 async fn test_nlj_memory_limited_empty_inputs() -> Result<()> {
3711 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3712
3713 let empty_left = build_table(
3715 ("a1", &vec![]),
3716 ("b1", &vec![]),
3717 ("c1", &vec![]),
3718 None,
3719 Vec::new(),
3720 );
3721 let right = build_right_table();
3722 let filter = prepare_join_filter();
3723
3724 let (_columns, batches, _metrics) =
3725 join_collect(empty_left, right, &JoinType::Inner, Some(filter), task_ctx)
3726 .await?;
3727 assert!(batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0));
3728
3729 let task_ctx2 = task_ctx_with_memory_limit(50, 16)?;
3731 let left = build_left_table();
3732 let empty_right = build_table(
3733 ("a2", &vec![]),
3734 ("b2", &vec![]),
3735 ("c2", &vec![]),
3736 None,
3737 Vec::new(),
3738 );
3739 let filter2 = prepare_join_filter();
3740
3741 let (_columns, batches, _metrics) = join_collect(
3742 left,
3743 empty_right,
3744 &JoinType::Inner,
3745 Some(filter2),
3746 task_ctx2,
3747 )
3748 .await?;
3749 assert!(batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0));
3750
3751 Ok(())
3752 }
3753
3754 #[tokio::test]
3755 async fn test_nlj_memory_limited_no_disk_falls_back_to_oom() -> Result<()> {
3756 use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
3758
3759 let runtime = RuntimeEnvBuilder::new()
3760 .with_memory_limit(100, 1.0)
3761 .with_disk_manager_builder(
3762 DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
3763 )
3764 .build_arc()?;
3765 let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
3766
3767 let left = build_left_table();
3768 let right = build_right_table();
3769 let filter = prepare_join_filter();
3770
3771 let err = join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx)
3772 .await
3773 .unwrap_err();
3774
3775 assert_contains!(err.to_string(), "Resources exhausted");
3776 Ok(())
3777 }
3778
3779 #[tokio::test]
3780 async fn test_nlj_memory_limited_right_join() -> Result<()> {
3781 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3782 let left = build_left_table();
3783 let right = build_right_table();
3784 let filter = prepare_join_filter();
3785
3786 let (columns, batches, metrics) =
3787 join_collect(left, right, &JoinType::Right, Some(filter), task_ctx).await?;
3788
3789 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3790
3791 assert!(
3793 metrics.spill_count().unwrap_or(0) > 0,
3794 "Expected spilling to occur under tight memory limit"
3795 );
3796
3797 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3799 +----+----+----+----+----+-----+
3800 | a1 | b1 | c1 | a2 | b2 | c2 |
3801 +----+----+----+----+----+-----+
3802 | | | | 10 | 10 | 100 |
3803 | | | | 12 | 10 | 40 |
3804 | 5 | 5 | 50 | 2 | 2 | 80 |
3805 +----+----+----+----+----+-----+
3806 "));
3807 Ok(())
3808 }
3809
3810 #[tokio::test]
3811 async fn test_nlj_memory_limited_full_join() -> Result<()> {
3812 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3813 let left = build_left_table();
3814 let right = build_right_table();
3815 let filter = prepare_join_filter();
3816
3817 let (columns, batches, metrics) =
3818 join_collect(left, right, &JoinType::Full, Some(filter), task_ctx).await?;
3819
3820 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
3821
3822 assert!(
3824 metrics.spill_count().unwrap_or(0) > 0,
3825 "Expected spilling to occur under tight memory limit"
3826 );
3827
3828 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3830 +----+----+-----+----+----+-----+
3831 | a1 | b1 | c1 | a2 | b2 | c2 |
3832 +----+----+-----+----+----+-----+
3833 | | | | 10 | 10 | 100 |
3834 | | | | 12 | 10 | 40 |
3835 | 11 | 8 | 110 | | | |
3836 | 5 | 5 | 50 | 2 | 2 | 80 |
3837 | 9 | 8 | 90 | | | |
3838 +----+----+-----+----+----+-----+
3839 "));
3840 Ok(())
3841 }
3842
3843 #[tokio::test]
3844 async fn test_nlj_memory_limited_right_semi_join() -> Result<()> {
3845 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3846 let left = build_left_table();
3847 let right = build_right_table();
3848 let filter = prepare_join_filter();
3849
3850 let (columns, batches, metrics) =
3851 join_collect(left, right, &JoinType::RightSemi, Some(filter), task_ctx)
3852 .await?;
3853
3854 assert_eq!(columns, vec!["a2", "b2", "c2"]);
3855
3856 assert!(
3857 metrics.spill_count().unwrap_or(0) > 0,
3858 "Expected spilling to occur under tight memory limit"
3859 );
3860
3861 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3863 +----+----+----+
3864 | a2 | b2 | c2 |
3865 +----+----+----+
3866 | 2 | 2 | 80 |
3867 +----+----+----+
3868 "));
3869 Ok(())
3870 }
3871
3872 #[tokio::test]
3873 async fn test_nlj_memory_limited_right_anti_join() -> Result<()> {
3874 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3875 let left = build_left_table();
3876 let right = build_right_table();
3877 let filter = prepare_join_filter();
3878
3879 let (columns, batches, metrics) =
3880 join_collect(left, right, &JoinType::RightAnti, Some(filter), task_ctx)
3881 .await?;
3882
3883 assert_eq!(columns, vec!["a2", "b2", "c2"]);
3884
3885 assert!(
3886 metrics.spill_count().unwrap_or(0) > 0,
3887 "Expected spilling to occur under tight memory limit"
3888 );
3889
3890 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3892 +----+----+-----+
3893 | a2 | b2 | c2 |
3894 +----+----+-----+
3895 | 10 | 10 | 100 |
3896 | 12 | 10 | 40 |
3897 +----+----+-----+
3898 "));
3899 Ok(())
3900 }
3901
3902 #[tokio::test]
3903 async fn test_nlj_memory_limited_right_mark_join() -> Result<()> {
3904 let task_ctx = task_ctx_with_memory_limit(50, 16)?;
3905 let left = build_left_table();
3906 let right = build_right_table();
3907 let filter = prepare_join_filter();
3908
3909 let (columns, batches, metrics) =
3910 join_collect(left, right, &JoinType::RightMark, Some(filter), task_ctx)
3911 .await?;
3912
3913 assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
3914
3915 assert!(
3916 metrics.spill_count().unwrap_or(0) > 0,
3917 "Expected spilling to occur under tight memory limit"
3918 );
3919
3920 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
3922 +----+----+-----+-------+
3923 | a2 | b2 | c2 | mark |
3924 +----+----+-----+-------+
3925 | 10 | 10 | 100 | false |
3926 | 12 | 10 | 40 | false |
3927 | 2 | 2 | 80 | true |
3928 +----+----+-----+-------+
3929 "));
3930 Ok(())
3931 }
3932}