1use std::any::Any;
21use std::fmt::Formatter;
22use std::ops::{BitOr, ControlFlow};
23use std::sync::Arc;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::task::Poll;
26
27use super::utils::{
28 asymmetric_join_output_partitioning, need_produce_result_in_final,
29 reorder_output_after_swap, swap_join_projection,
30};
31use crate::common::can_project;
32use crate::execution_plan::{EmissionType, boundedness_from_children};
33use crate::joins::SharedBitmapBuilder;
34use crate::joins::utils::{
35 BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut,
36 build_join_schema, check_join_is_valid, estimate_join_statistics,
37 need_produce_right_in_final,
38};
39use crate::metrics::{
40 Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, MetricsSet, RatioMetrics,
41};
42use crate::projection::{
43 EmbeddedProjection, JoinData, ProjectionExec, try_embed_projection,
44 try_pushdown_through_join,
45};
46use crate::{
47 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
48 PlanProperties, RecordBatchStream, SendableRecordBatchStream,
49 check_if_same_properties,
50};
51
52use arrow::array::{
53 Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, UInt32Array,
54 UInt64Array, new_null_array,
55};
56use arrow::buffer::BooleanBuffer;
57use arrow::compute::{
58 BatchCoalescer, concat_batches, filter, filter_record_batch, not, take,
59};
60use arrow::datatypes::{Schema, SchemaRef};
61use arrow::record_batch::RecordBatch;
62use arrow_schema::DataType;
63use datafusion_common::cast::as_boolean_array;
64use datafusion_common::{
65 JoinSide, Result, ScalarValue, Statistics, arrow_err, assert_eq_or_internal_err,
66 internal_datafusion_err, internal_err, project_schema, unwrap_or_internal_err,
67};
68use datafusion_execution::TaskContext;
69use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
70use datafusion_expr::JoinType;
71use datafusion_physical_expr::equivalence::{
72 ProjectionMapping, join_equivalence_properties,
73};
74
75use datafusion_physical_expr::projection::{ProjectionRef, combine_projections};
76use futures::{Stream, StreamExt, TryStreamExt};
77use log::debug;
78use parking_lot::Mutex;
79
#[expect(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
/// Physical plan node that joins two inputs with a nested-loop strategy.
///
/// The left child is collected once into a [`JoinLeftData`] and shared by all
/// output partitions; every partition of the right child is then streamed
/// against it by a [`NestedLoopJoinStream`].
pub struct NestedLoopJoinExec {
    /// Left input. It is executed as partition 0 and fully collected before
    /// probing starts.
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right input. One output partition is produced per right partition.
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Optional predicate evaluated on candidate (left, right) row pairs.
    pub(crate) filter: Option<JoinFilter>,
    /// Kind of join to perform.
    pub(crate) join_type: JoinType,
    /// Schema of the join output before `projection` is applied.
    join_schema: SchemaRef,
    /// Lazily initialised, shared result of collecting the left input.
    build_side_data: OnceAsync<JoinLeftData>,
    /// For each column of `join_schema`, the input side and column index it
    /// originates from.
    column_indices: Vec<ColumnIndex>,
    /// Optional list of `join_schema` column indices to emit.
    projection: Option<ProjectionRef>,

    /// Execution metrics recorded by the streams of this node.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed at construction time.
    cache: Arc<PlanProperties>,
}
204
/// Builder for [`NestedLoopJoinExec`].
///
/// The inputs and the join type are mandatory; the filter and the projection
/// default to `None`.
pub struct NestedLoopJoinExecBuilder {
    /// Left input of the join being built.
    left: Arc<dyn ExecutionPlan>,
    /// Right input of the join being built.
    right: Arc<dyn ExecutionPlan>,
    /// Kind of join to perform.
    join_type: JoinType,
    /// Optional join predicate.
    filter: Option<JoinFilter>,
    /// Optional output projection, expressed as join-schema column indices.
    projection: Option<ProjectionRef>,
}
213
214impl NestedLoopJoinExecBuilder {
215 pub fn new(
217 left: Arc<dyn ExecutionPlan>,
218 right: Arc<dyn ExecutionPlan>,
219 join_type: JoinType,
220 ) -> Self {
221 Self {
222 left,
223 right,
224 join_type,
225 filter: None,
226 projection: None,
227 }
228 }
229
230 pub fn with_projection(self, projection: Option<Vec<usize>>) -> Self {
232 self.with_projection_ref(projection.map(Into::into))
233 }
234
235 pub fn with_projection_ref(mut self, projection: Option<ProjectionRef>) -> Self {
237 self.projection = projection;
238 self
239 }
240
241 pub fn with_filter(mut self, filter: Option<JoinFilter>) -> Self {
243 self.filter = filter;
244 self
245 }
246
247 pub fn build(self) -> Result<NestedLoopJoinExec> {
249 let Self {
250 left,
251 right,
252 join_type,
253 filter,
254 projection,
255 } = self;
256
257 let left_schema = left.schema();
258 let right_schema = right.schema();
259 check_join_is_valid(&left_schema, &right_schema, &[])?;
260 let (join_schema, column_indices) =
261 build_join_schema(&left_schema, &right_schema, &join_type);
262 let join_schema = Arc::new(join_schema);
263 let cache = NestedLoopJoinExec::compute_properties(
264 &left,
265 &right,
266 &join_schema,
267 join_type,
268 projection.as_deref(),
269 )?;
270 Ok(NestedLoopJoinExec {
271 left,
272 right,
273 filter,
274 join_type,
275 join_schema,
276 build_side_data: Default::default(),
277 column_indices,
278 projection,
279 metrics: Default::default(),
280 cache: Arc::new(cache),
281 })
282 }
283}
284
285impl From<&NestedLoopJoinExec> for NestedLoopJoinExecBuilder {
286 fn from(exec: &NestedLoopJoinExec) -> Self {
287 Self {
288 left: Arc::clone(exec.left()),
289 right: Arc::clone(exec.right()),
290 join_type: exec.join_type,
291 filter: exec.filter.clone(),
292 projection: exec.projection.clone(),
293 }
294 }
295}
296
297impl NestedLoopJoinExec {
298 pub fn try_new(
300 left: Arc<dyn ExecutionPlan>,
301 right: Arc<dyn ExecutionPlan>,
302 filter: Option<JoinFilter>,
303 join_type: &JoinType,
304 projection: Option<Vec<usize>>,
305 ) -> Result<Self> {
306 NestedLoopJoinExecBuilder::new(left, right, *join_type)
307 .with_projection(projection)
308 .with_filter(filter)
309 .build()
310 }
311
312 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
314 &self.left
315 }
316
317 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
319 &self.right
320 }
321
322 pub fn filter(&self) -> Option<&JoinFilter> {
324 self.filter.as_ref()
325 }
326
327 pub fn join_type(&self) -> &JoinType {
329 &self.join_type
330 }
331
332 pub fn projection(&self) -> &Option<ProjectionRef> {
333 &self.projection
334 }
335
336 fn compute_properties(
338 left: &Arc<dyn ExecutionPlan>,
339 right: &Arc<dyn ExecutionPlan>,
340 schema: &SchemaRef,
341 join_type: JoinType,
342 projection: Option<&[usize]>,
343 ) -> Result<PlanProperties> {
344 let mut eq_properties = join_equivalence_properties(
346 left.equivalence_properties().clone(),
347 right.equivalence_properties().clone(),
348 &join_type,
349 Arc::clone(schema),
350 &Self::maintains_input_order(join_type),
351 None,
352 &[],
354 )?;
355
356 let mut output_partitioning =
357 asymmetric_join_output_partitioning(left, right, &join_type)?;
358
359 let emission_type = if left.boundedness().is_unbounded() {
360 EmissionType::Final
361 } else if right.pipeline_behavior() == EmissionType::Incremental {
362 match join_type {
363 JoinType::Inner
366 | JoinType::LeftSemi
367 | JoinType::RightSemi
368 | JoinType::Right
369 | JoinType::RightAnti
370 | JoinType::RightMark => EmissionType::Incremental,
371 JoinType::Left
374 | JoinType::LeftAnti
375 | JoinType::LeftMark
376 | JoinType::Full => EmissionType::Both,
377 }
378 } else {
379 right.pipeline_behavior()
380 };
381
382 if let Some(projection) = projection {
383 let projection_mapping = ProjectionMapping::from_indices(projection, schema)?;
385 let out_schema = project_schema(schema, Some(&projection))?;
386 output_partitioning =
387 output_partitioning.project(&projection_mapping, &eq_properties);
388 eq_properties = eq_properties.project(&projection_mapping, out_schema);
389 }
390
391 Ok(PlanProperties::new(
392 eq_properties,
393 output_partitioning,
394 emission_type,
395 boundedness_from_children([left, right]),
396 ))
397 }
398
399 fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
401 vec![false, false]
402 }
403
404 pub fn contains_projection(&self) -> bool {
405 self.projection.is_some()
406 }
407
408 pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
409 let projection = projection.map(Into::into);
410 can_project(&self.schema(), projection.as_deref())?;
412 let projection =
413 combine_projections(projection.as_ref(), self.projection.as_ref())?;
414 NestedLoopJoinExecBuilder::from(self)
415 .with_projection_ref(projection)
416 .build()
417 }
418
419 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
428 let left = self.left();
429 let right = self.right();
430 let new_join = NestedLoopJoinExec::try_new(
431 Arc::clone(right),
432 Arc::clone(left),
433 self.filter().map(JoinFilter::swap),
434 &self.join_type().swap(),
435 swap_join_projection(
436 left.schema().fields().len(),
437 right.schema().fields().len(),
438 self.projection.as_deref(),
439 self.join_type(),
440 ),
441 )?;
442
443 let plan: Arc<dyn ExecutionPlan> = if matches!(
446 self.join_type(),
447 JoinType::LeftSemi
448 | JoinType::RightSemi
449 | JoinType::LeftAnti
450 | JoinType::RightAnti
451 | JoinType::LeftMark
452 | JoinType::RightMark
453 ) || self.projection.is_some()
454 {
455 Arc::new(new_join)
456 } else {
457 reorder_output_after_swap(
458 Arc::new(new_join),
459 &self.left().schema(),
460 &self.right().schema(),
461 )?
462 };
463
464 Ok(plan)
465 }
466
467 fn with_new_children_and_same_properties(
468 &self,
469 mut children: Vec<Arc<dyn ExecutionPlan>>,
470 ) -> Self {
471 let left = children.swap_remove(0);
472 let right = children.swap_remove(0);
473
474 Self {
475 left,
476 right,
477 metrics: ExecutionPlanMetricsSet::new(),
478 build_side_data: Default::default(),
479 cache: Arc::clone(&self.cache),
480 filter: self.filter.clone(),
481 join_type: self.join_type,
482 join_schema: Arc::clone(&self.join_schema),
483 column_indices: self.column_indices.clone(),
484 projection: self.projection.clone(),
485 }
486 }
487}
488
489impl DisplayAs for NestedLoopJoinExec {
490 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
491 match t {
492 DisplayFormatType::Default | DisplayFormatType::Verbose => {
493 let display_filter = self.filter.as_ref().map_or_else(
494 || "".to_string(),
495 |f| format!(", filter={}", f.expression()),
496 );
497 let display_projections = if self.contains_projection() {
498 format!(
499 ", projection=[{}]",
500 self.projection
501 .as_ref()
502 .unwrap()
503 .iter()
504 .map(|index| format!(
505 "{}@{}",
506 self.join_schema.fields().get(*index).unwrap().name(),
507 index
508 ))
509 .collect::<Vec<_>>()
510 .join(", ")
511 )
512 } else {
513 "".to_string()
514 };
515 write!(
516 f,
517 "NestedLoopJoinExec: join_type={:?}{}{}",
518 self.join_type, display_filter, display_projections
519 )
520 }
521 DisplayFormatType::TreeRender => {
522 if *self.join_type() != JoinType::Inner {
523 writeln!(f, "join_type={:?}", self.join_type)
524 } else {
525 Ok(())
526 }
527 }
528 }
529 }
530}
531
532impl ExecutionPlan for NestedLoopJoinExec {
533 fn name(&self) -> &'static str {
534 "NestedLoopJoinExec"
535 }
536
537 fn as_any(&self) -> &dyn Any {
538 self
539 }
540
541 fn properties(&self) -> &Arc<PlanProperties> {
542 &self.cache
543 }
544
545 fn required_input_distribution(&self) -> Vec<Distribution> {
546 vec![
547 Distribution::SinglePartition,
548 Distribution::UnspecifiedDistribution,
549 ]
550 }
551
552 fn maintains_input_order(&self) -> Vec<bool> {
553 Self::maintains_input_order(self.join_type)
554 }
555
556 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
557 vec![&self.left, &self.right]
558 }
559
560 fn with_new_children(
561 self: Arc<Self>,
562 children: Vec<Arc<dyn ExecutionPlan>>,
563 ) -> Result<Arc<dyn ExecutionPlan>> {
564 check_if_same_properties!(self, children);
565 Ok(Arc::new(
566 NestedLoopJoinExecBuilder::new(
567 Arc::clone(&children[0]),
568 Arc::clone(&children[1]),
569 self.join_type,
570 )
571 .with_filter(self.filter.clone())
572 .with_projection_ref(self.projection.clone())
573 .build()?,
574 ))
575 }
576
577 fn execute(
578 &self,
579 partition: usize,
580 context: Arc<TaskContext>,
581 ) -> Result<SendableRecordBatchStream> {
582 assert_eq_or_internal_err!(
583 self.left.output_partitioning().partition_count(),
584 1,
585 "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
586 consider using CoalescePartitionsExec or the EnforceDistribution rule"
587 );
588
589 let metrics = NestedLoopJoinMetrics::new(&self.metrics, partition);
590
591 let load_reservation =
593 MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
594 .register(context.memory_pool());
595
596 let build_side_data = self.build_side_data.try_once(|| {
597 let stream = self.left.execute(0, Arc::clone(&context))?;
598
599 Ok(collect_left_input(
600 stream,
601 metrics.join_metrics.clone(),
602 load_reservation,
603 need_produce_result_in_final(self.join_type),
604 self.right().output_partitioning().partition_count(),
605 ))
606 })?;
607
608 let batch_size = context.session_config().batch_size();
609
610 let probe_side_data = self.right.execute(partition, context)?;
611
612 let column_indices_after_projection = match self.projection.as_ref() {
614 Some(projection) => projection
615 .iter()
616 .map(|i| self.column_indices[*i].clone())
617 .collect(),
618 None => self.column_indices.clone(),
619 };
620
621 Ok(Box::pin(NestedLoopJoinStream::new(
622 self.schema(),
623 self.filter.clone(),
624 self.join_type,
625 probe_side_data,
626 build_side_data,
627 column_indices_after_projection,
628 metrics,
629 batch_size,
630 )))
631 }
632
633 fn metrics(&self) -> Option<MetricsSet> {
634 Some(self.metrics.clone_inner())
635 }
636
637 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
638 let join_columns = Vec::new();
646
647 let left_stats = self.left.partition_statistics(None)?;
652 let right_stats = match partition {
653 Some(partition) => self.right.partition_statistics(Some(partition))?,
654 None => self.right.partition_statistics(None)?,
655 };
656
657 let stats = estimate_join_statistics(
658 left_stats,
659 right_stats,
660 &join_columns,
661 &self.join_type,
662 &self.join_schema,
663 )?;
664
665 Ok(stats.project(self.projection.as_ref()))
666 }
667
668 fn try_swapping_with_projection(
672 &self,
673 projection: &ProjectionExec,
674 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
675 if self.contains_projection() {
677 return Ok(None);
678 }
679
680 let schema = self.schema();
681 if let Some(JoinData {
682 projected_left_child,
683 projected_right_child,
684 join_filter,
685 ..
686 }) = try_pushdown_through_join(
687 projection,
688 self.left(),
689 self.right(),
690 &[],
691 &schema,
692 self.filter(),
693 )? {
694 Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
695 Arc::new(projected_left_child),
696 Arc::new(projected_right_child),
697 join_filter,
698 self.join_type(),
699 None,
701 )?)))
702 } else {
703 try_embed_projection(projection, self)
704 }
705 }
706}
707
708impl EmbeddedProjection for NestedLoopJoinExec {
709 fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
710 self.with_projection(projection)
711 }
712}
713
/// Collected left (build-side) input, shared by all probe streams.
pub(crate) struct JoinLeftData {
    /// All left-side batches concatenated into a single batch.
    batch: RecordBatch,
    /// Mutex-protected bitmap of left rows that have found a match. It has one
    /// bit per left row when tracking is requested and is empty otherwise.
    bitmap: SharedBitmapBuilder,
    /// Number of probe streams that have not yet reported completion.
    probe_threads_counter: AtomicUsize,
    /// Memory reservation for the buffered data. It is never read;
    /// presumably it is held so the reservation lives as long as this struct.
    #[expect(dead_code)]
    reservation: MemoryReservation,
}
728
729impl JoinLeftData {
730 pub(crate) fn new(
731 batch: RecordBatch,
732 bitmap: SharedBitmapBuilder,
733 probe_threads_counter: AtomicUsize,
734 reservation: MemoryReservation,
735 ) -> Self {
736 Self {
737 batch,
738 bitmap,
739 probe_threads_counter,
740 reservation,
741 }
742 }
743
744 pub(crate) fn batch(&self) -> &RecordBatch {
745 &self.batch
746 }
747
748 pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
749 &self.bitmap
750 }
751
752 pub(crate) fn report_probe_completed(&self) -> bool {
755 self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
756 }
757}
758
759async fn collect_left_input(
761 stream: SendableRecordBatchStream,
762 join_metrics: BuildProbeJoinMetrics,
763 reservation: MemoryReservation,
764 with_visited_left_side: bool,
765 probe_threads_count: usize,
766) -> Result<JoinLeftData> {
767 let schema = stream.schema();
768
769 let (batches, metrics, reservation) = stream
771 .try_fold(
772 (Vec::new(), join_metrics, reservation),
773 |(mut batches, metrics, reservation), batch| async {
774 let batch_size = batch.get_array_memory_size();
775 reservation.try_grow(batch_size)?;
777 metrics.build_mem_used.add(batch_size);
779 metrics.build_input_batches.add(1);
780 metrics.build_input_rows.add(batch.num_rows());
781 batches.push(batch);
783 Ok((batches, metrics, reservation))
784 },
785 )
786 .await?;
787
788 let merged_batch = concat_batches(&schema, &batches)?;
789
790 let visited_left_side = if with_visited_left_side {
792 let n_rows = merged_batch.num_rows();
793 let buffer_size = n_rows.div_ceil(8);
794 reservation.try_grow(buffer_size)?;
795 metrics.build_mem_used.add(buffer_size);
796
797 let mut buffer = BooleanBufferBuilder::new(n_rows);
798 buffer.append_n(n_rows, false);
799 buffer
800 } else {
801 BooleanBufferBuilder::new(0)
802 };
803
804 Ok(JoinLeftData::new(
805 merged_batch,
806 Mutex::new(visited_left_side),
807 AtomicUsize::new(probe_threads_count),
808 reservation,
809 ))
810}
811
/// States of the [`NestedLoopJoinStream`] state machine.
#[derive(Debug, Clone, Copy)]
enum NLJState {
    /// Initial state: waiting for the left input to be collected. Moves to
    /// `FetchingRight` once it is available.
    BufferingLeft,
    /// Polling the right input for the next batch. Empty batches are skipped;
    /// a non-empty batch moves to `ProbeRight`, and the end of the right input
    /// moves to `EmitLeftUnmatched`.
    FetchingRight,
    /// Joining the current right batch against the buffered left rows. When
    /// all left rows were processed, moves to `EmitRightUnmatched` if
    /// unmatched right rows are tracked, otherwise back to `FetchingRight`.
    ProbeRight,
    /// Handling unmatched rows of the current right batch, then returning to
    /// `FetchingRight`.
    EmitRightUnmatched,
    /// Handling unmatched left rows after the right input is exhausted, then
    /// moving to `Done`.
    EmitLeftUnmatched,
    /// Flushing any remaining output and ending the stream.
    Done,
}
/// Stream that produces the join output for one right-side partition.
pub(crate) struct NestedLoopJoinStream {
    /// Schema of the batches produced by this stream.
    pub(crate) output_schema: Arc<Schema>,
    /// Optional predicate evaluated on candidate row pairs.
    pub(crate) join_filter: Option<JoinFilter>,
    /// Kind of join performed.
    pub(crate) join_type: JoinType,
    /// Right (probe-side) input stream.
    pub(crate) right_data: SendableRecordBatchStream,
    /// Future resolving to the shared, collected left input.
    pub(crate) left_data: OnceFut<JoinLeftData>,
    /// For each output column, the input side and column index to read from.
    pub(crate) column_indices: Vec<ColumnIndex>,
    /// Metrics recorded by this stream.
    pub(crate) metrics: NestedLoopJoinMetrics,

    /// Target row count passed to `output_buffer`; also used to decide how
    /// many left rows are joined per probe step.
    batch_size: usize,

    /// Whether matches of right rows must be tracked, as decided by
    /// `need_produce_right_in_final` for the join type.
    should_track_unmatched_right: bool,

    /// Current state of the state machine.
    state: NLJState,
    /// Coalescer that accumulates join results into output batches.
    output_buffer: Box<BatchCoalescer>,
    /// Set once the `Done` state has emitted its single empty batch.
    handled_empty_output: bool,

    /// The collected left input, available once `left_data` has resolved.
    buffered_left_data: Option<Arc<JoinLeftData>>,
    /// Index of the next left row to join with the current right batch.
    left_probe_idx: usize,
    /// Starts at 0; presumably the progress cursor used while emitting
    /// unmatched left rows (its users are outside this view).
    left_emit_idx: usize,
    /// Set when the left input has been buffered.
    left_exhausted: bool,
    /// Initialised to `true` and never read.
    #[expect(dead_code)]
    left_buffered_in_one_pass: bool,

    /// Right batch currently being probed, if any.
    current_right_batch: Option<RecordBatch>,
    /// Per-row match flags for `current_right_batch`; only maintained when
    /// `should_track_unmatched_right` is set.
    current_right_batch_matched: Option<BooleanArray>,
}
901
/// Metrics recorded by one [`NestedLoopJoinStream`].
pub(crate) struct NestedLoopJoinMetrics {
    /// Build/probe metrics shared with the other join operators.
    pub(crate) join_metrics: BuildProbeJoinMetrics,
    /// Ratio metric registered under the name "selectivity".
    pub(crate) selectivity: RatioMetrics,
}
908
909impl NestedLoopJoinMetrics {
910 pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
911 Self {
912 join_metrics: BuildProbeJoinMetrics::new(partition, metrics),
913 selectivity: MetricBuilder::new(metrics)
914 .with_type(MetricType::SUMMARY)
915 .ratio_metrics("selectivity", partition),
916 }
917 }
918}
919
920impl Stream for NestedLoopJoinStream {
921 type Item = Result<RecordBatch>;
922
923 fn poll_next(
954 mut self: std::pin::Pin<&mut Self>,
955 cx: &mut std::task::Context<'_>,
956 ) -> Poll<Option<Self::Item>> {
957 loop {
958 match self.state {
959 NLJState::BufferingLeft => {
965 debug!("[NLJState] Entering: {:?}", self.state);
966 let build_metric = self.metrics.join_metrics.build_time.clone();
971 let _build_timer = build_metric.timer();
972
973 match self.handle_buffering_left(cx) {
974 ControlFlow::Continue(()) => continue,
975 ControlFlow::Break(poll) => return poll,
976 }
977 }
978
979 NLJState::FetchingRight => {
1002 debug!("[NLJState] Entering: {:?}", self.state);
1003 let join_metric = self.metrics.join_metrics.join_time.clone();
1005 let _join_timer = join_metric.timer();
1006
1007 match self.handle_fetching_right(cx) {
1008 ControlFlow::Continue(()) => continue,
1009 ControlFlow::Break(poll) => return poll,
1010 }
1011 }
1012
1013 NLJState::ProbeRight => {
1028 debug!("[NLJState] Entering: {:?}", self.state);
1029
1030 let join_metric = self.metrics.join_metrics.join_time.clone();
1032 let _join_timer = join_metric.timer();
1033
1034 match self.handle_probe_right() {
1035 ControlFlow::Continue(()) => continue,
1036 ControlFlow::Break(poll) => {
1037 return self.metrics.join_metrics.baseline.record_poll(poll);
1038 }
1039 }
1040 }
1041
1042 NLJState::EmitRightUnmatched => {
1049 debug!("[NLJState] Entering: {:?}", self.state);
1050
1051 let join_metric = self.metrics.join_metrics.join_time.clone();
1053 let _join_timer = join_metric.timer();
1054
1055 match self.handle_emit_right_unmatched() {
1056 ControlFlow::Continue(()) => continue,
1057 ControlFlow::Break(poll) => {
1058 return self.metrics.join_metrics.baseline.record_poll(poll);
1059 }
1060 }
1061 }
1062
1063 NLJState::EmitLeftUnmatched => {
1079 debug!("[NLJState] Entering: {:?}", self.state);
1080
1081 let join_metric = self.metrics.join_metrics.join_time.clone();
1083 let _join_timer = join_metric.timer();
1084
1085 match self.handle_emit_left_unmatched() {
1086 ControlFlow::Continue(()) => continue,
1087 ControlFlow::Break(poll) => {
1088 return self.metrics.join_metrics.baseline.record_poll(poll);
1089 }
1090 }
1091 }
1092
1093 NLJState::Done => {
1095 debug!("[NLJState] Entering: {:?}", self.state);
1096
1097 let join_metric = self.metrics.join_metrics.join_time.clone();
1099 let _join_timer = join_metric.timer();
1100 let poll = self.handle_done();
1104 return self.metrics.join_metrics.baseline.record_poll(poll);
1105 }
1106 }
1107 }
1108 }
1109}
1110
1111impl RecordBatchStream for NestedLoopJoinStream {
1112 fn schema(&self) -> SchemaRef {
1113 Arc::clone(&self.output_schema)
1114 }
1115}
1116
1117impl NestedLoopJoinStream {
1118 #[expect(clippy::too_many_arguments)]
1119 pub(crate) fn new(
1120 schema: Arc<Schema>,
1121 filter: Option<JoinFilter>,
1122 join_type: JoinType,
1123 right_data: SendableRecordBatchStream,
1124 left_data: OnceFut<JoinLeftData>,
1125 column_indices: Vec<ColumnIndex>,
1126 metrics: NestedLoopJoinMetrics,
1127 batch_size: usize,
1128 ) -> Self {
1129 Self {
1130 output_schema: Arc::clone(&schema),
1131 join_filter: filter,
1132 join_type,
1133 right_data,
1134 column_indices,
1135 left_data,
1136 metrics,
1137 buffered_left_data: None,
1138 output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
1139 batch_size,
1140 current_right_batch: None,
1141 current_right_batch_matched: None,
1142 state: NLJState::BufferingLeft,
1143 left_probe_idx: 0,
1144 left_emit_idx: 0,
1145 left_exhausted: false,
1146 left_buffered_in_one_pass: true,
1147 handled_empty_output: false,
1148 should_track_unmatched_right: need_produce_right_in_final(join_type),
1149 }
1150 }
1151
1152 fn handle_buffering_left(
1156 &mut self,
1157 cx: &mut std::task::Context<'_>,
1158 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1159 match self.left_data.get_shared(cx) {
1160 Poll::Ready(Ok(left_data)) => {
1161 self.buffered_left_data = Some(left_data);
1162 self.left_exhausted = true;
1164 self.state = NLJState::FetchingRight;
1165 ControlFlow::Continue(())
1167 }
1168 Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1169 Poll::Pending => ControlFlow::Break(Poll::Pending),
1170 }
1171 }
1172
1173 fn handle_fetching_right(
1175 &mut self,
1176 cx: &mut std::task::Context<'_>,
1177 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1178 match self.right_data.poll_next_unpin(cx) {
1179 Poll::Ready(result) => match result {
1180 Some(Ok(right_batch)) => {
1181 let right_batch_size = right_batch.num_rows();
1183 self.metrics.join_metrics.input_rows.add(right_batch_size);
1184 self.metrics.join_metrics.input_batches.add(1);
1185
1186 if right_batch_size == 0 {
1188 return ControlFlow::Continue(());
1189 }
1190
1191 self.current_right_batch = Some(right_batch);
1192
1193 if self.should_track_unmatched_right {
1195 let zeroed_buf = BooleanBuffer::new_unset(right_batch_size);
1196 self.current_right_batch_matched =
1197 Some(BooleanArray::new(zeroed_buf, None));
1198 }
1199
1200 self.left_probe_idx = 0;
1201 self.state = NLJState::ProbeRight;
1202 ControlFlow::Continue(())
1203 }
1204 Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1205 None => {
1206 self.state = NLJState::EmitLeftUnmatched;
1208 ControlFlow::Continue(())
1209 }
1210 },
1211 Poll::Pending => ControlFlow::Break(Poll::Pending),
1212 }
1213 }
1214
1215 fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1217 if let Some(poll) = self.maybe_flush_ready_batch() {
1219 return ControlFlow::Break(poll);
1220 }
1221
1222 match self.process_probe_batch() {
1224 Ok(true) => ControlFlow::Continue(()),
1228 Ok(false) => {
1232 self.left_probe_idx = 0;
1234
1235 if let (Ok(left_data), Some(right_batch)) =
1238 (self.get_left_data(), self.current_right_batch.as_ref())
1239 {
1240 let left_rows = left_data.batch().num_rows();
1241 let right_rows = right_batch.num_rows();
1242 self.metrics.selectivity.add_total(left_rows * right_rows);
1243 }
1244
1245 if self.should_track_unmatched_right {
1246 debug_assert!(
1247 self.current_right_batch_matched.is_some(),
1248 "If it's required to track matched rows in the right input, the right bitmap must be present"
1249 );
1250 self.state = NLJState::EmitRightUnmatched;
1251 } else {
1252 self.current_right_batch = None;
1253 self.state = NLJState::FetchingRight;
1254 }
1255 ControlFlow::Continue(())
1256 }
1257 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1258 }
1259 }
1260
1261 fn handle_emit_right_unmatched(
1263 &mut self,
1264 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1265 if let Some(poll) = self.maybe_flush_ready_batch() {
1267 return ControlFlow::Break(poll);
1268 }
1269
1270 debug_assert!(
1271 self.current_right_batch_matched.is_some()
1272 && self.current_right_batch.is_some(),
1273 "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
1274 );
1275 match self.process_right_unmatched() {
1277 Ok(Some(batch)) => {
1278 match self.output_buffer.push_batch(batch) {
1279 Ok(()) => {
1280 debug_assert!(self.current_right_batch.is_none());
1283 self.state = NLJState::FetchingRight;
1284 ControlFlow::Continue(())
1285 }
1286 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1287 }
1288 }
1289 Ok(None) => {
1290 debug_assert!(self.current_right_batch.is_none());
1293 self.state = NLJState::FetchingRight;
1294 ControlFlow::Continue(())
1295 }
1296 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1297 }
1298 }
1299
1300 fn handle_emit_left_unmatched(
1302 &mut self,
1303 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1304 if let Some(poll) = self.maybe_flush_ready_batch() {
1306 return ControlFlow::Break(poll);
1307 }
1308
1309 match self.process_left_unmatched() {
1311 Ok(true) => ControlFlow::Continue(()),
1314 Ok(false) => match self.output_buffer.finish_buffered_batch() {
1317 Ok(()) => {
1318 self.state = NLJState::Done;
1319 ControlFlow::Continue(())
1320 }
1321 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1322 },
1323 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1324 }
1325 }
1326
1327 fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
1329 if let Some(poll) = self.maybe_flush_ready_batch() {
1331 return poll;
1332 }
1333
1334 if !self.handled_empty_output {
1340 let zero_count = Count::new();
1341 if *self.metrics.join_metrics.baseline.output_rows() == zero_count {
1342 let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
1343 self.handled_empty_output = true;
1344 return Poll::Ready(Some(Ok(empty_batch)));
1345 }
1346 }
1347
1348 Poll::Ready(None)
1349 }
1350
1351 fn process_probe_batch(&mut self) -> Result<bool> {
1358 let left_data = Arc::clone(self.get_left_data()?);
1359 let right_batch = self
1360 .current_right_batch
1361 .as_ref()
1362 .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
1363 .clone();
1364
1365 if self.left_probe_idx >= left_data.batch().num_rows() {
1367 return Ok(false);
1368 }
1369
1370 debug_assert_ne!(
1383 right_batch.num_rows(),
1384 0,
1385 "When fetching the right batch, empty batches will be skipped"
1386 );
1387
1388 let l_row_cnt_ratio = self.batch_size / right_batch.num_rows();
1389 if l_row_cnt_ratio > 10 {
1390 let l_row_count = std::cmp::min(
1394 l_row_cnt_ratio,
1395 left_data.batch().num_rows() - self.left_probe_idx,
1396 );
1397
1398 debug_assert!(
1399 l_row_count != 0,
1400 "This function should only be entered when there are remaining left rows to process"
1401 );
1402 let joined_batch = self.process_left_range_join(
1403 &left_data,
1404 &right_batch,
1405 self.left_probe_idx,
1406 l_row_count,
1407 )?;
1408
1409 if let Some(batch) = joined_batch {
1410 self.output_buffer.push_batch(batch)?;
1411 }
1412
1413 self.left_probe_idx += l_row_count;
1414
1415 return Ok(true);
1416 }
1417
1418 let l_idx = self.left_probe_idx;
1419 let joined_batch =
1420 self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
1421
1422 if let Some(batch) = joined_batch {
1423 self.output_buffer.push_batch(batch)?;
1424 }
1425
1426 self.left_probe_idx += 1;
1430
1431 Ok(true)
1433 }
1434
1435 fn process_left_range_join(
1441 &mut self,
1442 left_data: &JoinLeftData,
1443 right_batch: &RecordBatch,
1444 l_start_index: usize,
1445 l_row_count: usize,
1446 ) -> Result<Option<RecordBatch>> {
1447 let right_rows = right_batch.num_rows();
1453 let total_rows = l_row_count * right_rows;
1454
1455 let left_indices: UInt32Array =
1457 UInt32Array::from_iter_values((0..l_row_count).flat_map(|i| {
1458 std::iter::repeat_n((l_start_index + i) as u32, right_rows)
1459 }));
1460 let right_indices: UInt32Array = UInt32Array::from_iter_values(
1461 (0..l_row_count).flat_map(|_| 0..right_rows as u32),
1462 );
1463
1464 debug_assert!(
1465 left_indices.len() == right_indices.len()
1466 && right_indices.len() == total_rows,
1467 "The length or cartesian product should be (left_size * right_size)",
1468 );
1469
1470 let bitmap_combined = if let Some(filter) = &self.join_filter {
1473 let intermediate_batch = if filter.schema.fields().is_empty() {
1475 create_record_batch_with_empty_schema(
1477 Arc::new((*filter.schema).clone()),
1478 total_rows,
1479 )?
1480 } else {
1481 let mut filter_columns: Vec<Arc<dyn Array>> =
1482 Vec::with_capacity(filter.column_indices().len());
1483 for column_index in filter.column_indices() {
1484 let array = if column_index.side == JoinSide::Left {
1485 let col = left_data.batch().column(column_index.index);
1486 take(col.as_ref(), &left_indices, None)?
1487 } else {
1488 let col = right_batch.column(column_index.index);
1489 take(col.as_ref(), &right_indices, None)?
1490 };
1491 filter_columns.push(array);
1492 }
1493
1494 RecordBatch::try_new(Arc::new((*filter.schema).clone()), filter_columns)?
1495 };
1496
1497 let filter_result = filter
1498 .expression()
1499 .evaluate(&intermediate_batch)?
1500 .into_array(intermediate_batch.num_rows())?;
1501 let filter_arr = as_boolean_array(&filter_result)?;
1502
1503 boolean_mask_from_filter(filter_arr)
1505 } else {
1506 BooleanArray::from(vec![true; total_rows])
1508 };
1509
1510 let mut left_bitmap = if need_produce_result_in_final(self.join_type) {
1515 Some(left_data.bitmap().lock())
1516 } else {
1517 None
1518 };
1519
1520 let mut local_right_bitmap = if self.should_track_unmatched_right {
1524 let mut current_right_batch_bitmap = BooleanBufferBuilder::new(right_rows);
1525 current_right_batch_bitmap.append_n(right_rows, false);
1527 Some(current_right_batch_bitmap)
1528 } else {
1529 None
1530 };
1531
1532 for (i, is_matched) in bitmap_combined.iter().enumerate() {
1534 let is_matched = is_matched.ok_or_else(|| {
1535 internal_datafusion_err!("Must be Some after the previous combining step")
1536 })?;
1537
1538 let l_index = l_start_index + i / right_rows;
1539 let r_index = i % right_rows;
1540
1541 if let Some(bitmap) = left_bitmap.as_mut()
1542 && is_matched
1543 {
1544 bitmap.set_bit(l_index, true);
1546 }
1547
1548 if let Some(bitmap) = local_right_bitmap.as_mut()
1549 && is_matched
1550 {
1551 bitmap.set_bit(r_index, true);
1552 }
1553 }
1554
1555 if self.should_track_unmatched_right {
1557 let global_right_bitmap =
1559 std::mem::take(&mut self.current_right_batch_matched).ok_or_else(
1560 || internal_datafusion_err!("right batch's bitmap should be present"),
1561 )?;
1562 let (buf, nulls) = global_right_bitmap.into_parts();
1563 debug_assert!(nulls.is_none());
1564
1565 let current_right_bitmap = local_right_bitmap
1566 .ok_or_else(|| {
1567 internal_datafusion_err!(
1568 "Should be Some if the current join type requires right bitmap"
1569 )
1570 })?
1571 .finish();
1572 let updated_global_right_bitmap = buf.bitor(¤t_right_bitmap);
1573
1574 self.current_right_batch_matched =
1575 Some(BooleanArray::new(updated_global_right_bitmap, None));
1576 }
1577
1578 if matches!(
1580 self.join_type,
1581 JoinType::LeftAnti
1582 | JoinType::LeftSemi
1583 | JoinType::LeftMark
1584 | JoinType::RightAnti
1585 | JoinType::RightMark
1586 | JoinType::RightSemi
1587 ) {
1588 return Ok(None);
1589 }
1590
1591 if self.output_schema.fields().is_empty() {
1594 let row_count = bitmap_combined.true_count();
1596 return Ok(Some(create_record_batch_with_empty_schema(
1597 Arc::clone(&self.output_schema),
1598 row_count,
1599 )?));
1600 }
1601
1602 let mut out_columns: Vec<Arc<dyn Array>> =
1603 Vec::with_capacity(self.output_schema.fields().len());
1604 for column_index in &self.column_indices {
1605 let array = if column_index.side == JoinSide::Left {
1606 let col = left_data.batch().column(column_index.index);
1607 take(col.as_ref(), &left_indices, None)?
1608 } else {
1609 let col = right_batch.column(column_index.index);
1610 take(col.as_ref(), &right_indices, None)?
1611 };
1612 out_columns.push(array);
1613 }
1614 let pre_filtered =
1615 RecordBatch::try_new(Arc::clone(&self.output_schema), out_columns)?;
1616 let filtered = filter_record_batch(&pre_filtered, &bitmap_combined)?;
1617 Ok(Some(filtered))
1618 }
1619
1620 fn process_single_left_row_join(
1626 &mut self,
1627 left_data: &JoinLeftData,
1628 right_batch: &RecordBatch,
1629 l_index: usize,
1630 ) -> Result<Option<RecordBatch>> {
1631 let right_row_count = right_batch.num_rows();
1632 if right_row_count == 0 {
1633 return Ok(None);
1634 }
1635
1636 let cur_right_bitmap = if let Some(filter) = &self.join_filter {
1637 apply_filter_to_row_join_batch(
1638 left_data.batch(),
1639 l_index,
1640 right_batch,
1641 filter,
1642 )?
1643 } else {
1644 BooleanArray::from(vec![true; right_row_count])
1645 };
1646
1647 self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
1648
1649 if matches!(
1652 self.join_type,
1653 JoinType::LeftAnti
1654 | JoinType::LeftSemi
1655 | JoinType::LeftMark
1656 | JoinType::RightAnti
1657 | JoinType::RightMark
1658 | JoinType::RightSemi
1659 ) {
1660 return Ok(None);
1661 }
1662
1663 if cur_right_bitmap.true_count() == 0 {
1664 Ok(None)
1666 } else {
1667 let join_batch = build_row_join_batch(
1669 &self.output_schema,
1670 left_data.batch(),
1671 l_index,
1672 right_batch,
1673 Some(cur_right_bitmap),
1674 &self.column_indices,
1675 JoinSide::Left,
1676 )?;
1677 Ok(join_batch)
1678 }
1679 }
1680
1681 fn process_left_unmatched(&mut self) -> Result<bool> {
1685 let left_data = self.get_left_data()?;
1686 let left_batch = left_data.batch();
1687
1688 let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
1694 let handled_by_other_partition =
1696 self.left_emit_idx == 0 && !left_data.report_probe_completed();
1697 let finished = self.left_emit_idx >= left_batch.num_rows();
1699
1700 if join_type_no_produce_left || handled_by_other_partition || finished {
1701 return Ok(false);
1702 }
1703
1704 let start_idx = self.left_emit_idx;
1709 let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
1710
1711 if let Some(batch) =
1712 self.process_left_unmatched_range(left_data, start_idx, end_idx)?
1713 {
1714 self.output_buffer.push_batch(batch)?;
1715 }
1716
1717 self.left_emit_idx = end_idx;
1719
1720 Ok(true)
1722 }
1723
1724 fn process_left_unmatched_range(
1737 &self,
1738 left_data: &JoinLeftData,
1739 start_idx: usize,
1740 end_idx: usize,
1741 ) -> Result<Option<RecordBatch>> {
1742 if start_idx == end_idx {
1743 return Ok(None);
1744 }
1745
1746 let left_batch = left_data.batch();
1749 let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx);
1750
1751 let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx);
1753 bitmap_sliced.append_n(end_idx - start_idx, false);
1754 let bitmap = left_data.bitmap().lock();
1755 for i in start_idx..end_idx {
1756 assert!(
1757 i - start_idx < bitmap_sliced.capacity(),
1758 "DBG: {start_idx}, {end_idx}"
1759 );
1760 bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
1761 }
1762 let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
1763
1764 let right_schema = self.right_data.schema();
1765 build_unmatched_batch(
1766 &self.output_schema,
1767 &left_batch_sliced,
1768 bitmap_sliced,
1769 &right_schema,
1770 &self.column_indices,
1771 self.join_type,
1772 JoinSide::Left,
1773 )
1774 }
1775
1776 fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
1779 let right_batch_bitmap: BooleanArray =
1781 std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
1782 internal_datafusion_err!("right bitmap should be available")
1783 })?;
1784
1785 let right_batch = self.current_right_batch.take();
1786 let cur_right_batch = unwrap_or_internal_err!(right_batch);
1787
1788 let left_data = self.get_left_data()?;
1789 let left_schema = left_data.batch().schema();
1790
1791 let res = build_unmatched_batch(
1792 &self.output_schema,
1793 &cur_right_batch,
1794 right_batch_bitmap,
1795 &left_schema,
1796 &self.column_indices,
1797 self.join_type,
1798 JoinSide::Right,
1799 );
1800
1801 self.current_right_batch_matched = None;
1803
1804 res
1805 }
1806
1807 fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
1811 self.buffered_left_data
1812 .as_ref()
1813 .ok_or_else(|| internal_datafusion_err!("LeftData should be available"))
1814 }
1815
1816 fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
1819 if self.output_buffer.has_completed_batch()
1820 && let Some(batch) = self.output_buffer.next_completed_batch()
1821 {
1822 let output_rows = batch.num_rows();
1824 self.metrics.selectivity.add_part(output_rows);
1825
1826 return Some(Poll::Ready(Some(Ok(batch))));
1827 }
1828
1829 None
1830 }
1831
1832 fn update_matched_bitmap(
1848 &mut self,
1849 l_index: usize,
1850 r_matched_bitmap: &BooleanArray,
1851 ) -> Result<()> {
1852 let left_data = self.get_left_data()?;
1853
1854 let joined_len = r_matched_bitmap.true_count();
1856
1857 if need_produce_result_in_final(self.join_type) && (joined_len > 0) {
1859 let mut bitmap = left_data.bitmap().lock();
1860 bitmap.set_bit(l_index, true);
1861 }
1862
1863 if self.should_track_unmatched_right {
1865 debug_assert!(self.current_right_batch_matched.is_some());
1866 let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
1868 .ok_or_else(|| {
1869 internal_datafusion_err!("right batch's bitmap should be present")
1870 })?;
1871 let (buf, nulls) = right_bitmap.into_parts();
1872 debug_assert!(nulls.is_none());
1873 let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
1874
1875 self.current_right_batch_matched =
1876 Some(BooleanArray::new(updated_right_bitmap, None));
1877 }
1878
1879 Ok(())
1880 }
1881}
1882
1883fn apply_filter_to_row_join_batch(
1889 left_batch: &RecordBatch,
1890 l_index: usize,
1891 right_batch: &RecordBatch,
1892 filter: &JoinFilter,
1893) -> Result<BooleanArray> {
1894 debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
1895
1896 let intermediate_batch = if filter.schema.fields().is_empty() {
1897 create_record_batch_with_empty_schema(
1900 Arc::new((*filter.schema).clone()),
1901 right_batch.num_rows(),
1902 )?
1903 } else {
1904 build_row_join_batch(
1905 &filter.schema,
1906 left_batch,
1907 l_index,
1908 right_batch,
1909 None,
1910 &filter.column_indices,
1911 JoinSide::Left,
1912 )?
1913 .ok_or_else(|| internal_datafusion_err!("This function assume input batch is not empty, so the intermediate batch can't be empty too"))?
1914 };
1915
1916 let filter_result = filter
1917 .expression()
1918 .evaluate(&intermediate_batch)?
1919 .into_array(intermediate_batch.num_rows())?;
1920 let filter_arr = as_boolean_array(&filter_result)?;
1921
1922 let bitmap_combined = boolean_mask_from_filter(filter_arr);
1924
1925 Ok(bitmap_combined)
1926}
1927
1928#[inline]
1934fn boolean_mask_from_filter(filter_arr: &BooleanArray) -> BooleanArray {
1935 let (values, nulls) = filter_arr.clone().into_parts();
1936 match nulls {
1937 Some(nulls) => BooleanArray::new(nulls.inner() & &values, None),
1938 None => BooleanArray::new(values, None),
1939 }
1940}
1941
1942fn build_row_join_batch(
1990 output_schema: &Schema,
1991 build_side_batch: &RecordBatch,
1992 build_side_index: usize,
1993 probe_side_batch: &RecordBatch,
1994 probe_side_filter: Option<BooleanArray>,
1995 col_indices: &[ColumnIndex],
1997 build_side: JoinSide,
2000) -> Result<Option<RecordBatch>> {
2001 debug_assert!(build_side != JoinSide::None);
2002
2003 let filtered_probe_batch = if let Some(filter) = probe_side_filter {
2006 &filter_record_batch(probe_side_batch, &filter)?
2007 } else {
2008 probe_side_batch
2009 };
2010
2011 if filtered_probe_batch.num_rows() == 0 {
2012 return Ok(None);
2013 }
2014
2015 if output_schema.fields.is_empty() {
2023 return Ok(Some(create_record_batch_with_empty_schema(
2024 Arc::new(output_schema.clone()),
2025 filtered_probe_batch.num_rows(),
2026 )?));
2027 }
2028
2029 let mut columns: Vec<Arc<dyn Array>> =
2030 Vec::with_capacity(output_schema.fields().len());
2031
2032 for column_index in col_indices {
2033 let array = if column_index.side == build_side {
2034 let original_left_array = build_side_batch.column(column_index.index);
2037
2038 match original_left_array.data_type() {
2044 DataType::List(field) | DataType::LargeList(field)
2045 if field.data_type() == &DataType::Utf8View =>
2046 {
2047 let indices_iter = std::iter::repeat_n(
2048 build_side_index as u64,
2049 filtered_probe_batch.num_rows(),
2050 );
2051 let indices_array = UInt64Array::from_iter_values(indices_iter);
2052 take(original_left_array.as_ref(), &indices_array, None)?
2053 }
2054 _ => {
2055 let scalar_value = ScalarValue::try_from_array(
2056 original_left_array.as_ref(),
2057 build_side_index,
2058 )?;
2059 scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
2060 }
2061 }
2062 } else {
2063 Arc::clone(filtered_probe_batch.column(column_index.index))
2065 };
2066
2067 columns.push(array);
2068 }
2069
2070 Ok(Some(RecordBatch::try_new(
2071 Arc::new(output_schema.clone()),
2072 columns,
2073 )?))
2074}
2075
2076fn build_unmatched_batch_empty_schema(
2083 output_schema: &SchemaRef,
2084 batch_bitmap: &BooleanArray,
2085 join_type: JoinType,
2087) -> Result<Option<RecordBatch>> {
2088 let result_size = match join_type {
2089 JoinType::Left
2090 | JoinType::Right
2091 | JoinType::Full
2092 | JoinType::LeftAnti
2093 | JoinType::RightAnti => batch_bitmap.false_count(),
2094 JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
2095 JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
2096 _ => unreachable!(),
2097 };
2098
2099 if output_schema.fields().is_empty() {
2100 Ok(Some(create_record_batch_with_empty_schema(
2101 Arc::clone(output_schema),
2102 result_size,
2103 )?))
2104 } else {
2105 Ok(None)
2106 }
2107}
2108
2109fn create_record_batch_with_empty_schema(
2113 schema: SchemaRef,
2114 row_count: usize,
2115) -> Result<RecordBatch> {
2116 let options = RecordBatchOptions::new()
2117 .with_match_field_names(true)
2118 .with_row_count(Some(row_count));
2119
2120 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| {
2121 internal_datafusion_err!("Failed to create empty record batch: {}", e)
2122 })
2123}
2124
2125fn build_unmatched_batch(
2161 output_schema: &SchemaRef,
2162 batch: &RecordBatch,
2163 batch_bitmap: BooleanArray,
2164 another_side_schema: &SchemaRef,
2166 col_indices: &[ColumnIndex],
2167 join_type: JoinType,
2168 batch_side: JoinSide,
2169) -> Result<Option<RecordBatch>> {
2170 debug_assert_ne!(join_type, JoinType::Inner);
2172 debug_assert_ne!(batch_side, JoinSide::None);
2173
2174 if let Some(batch) =
2176 build_unmatched_batch_empty_schema(output_schema, &batch_bitmap, join_type)?
2177 {
2178 return Ok(Some(batch));
2179 }
2180
2181 match join_type {
2182 JoinType::Full | JoinType::Right | JoinType::Left => {
2183 if join_type == JoinType::Right {
2184 debug_assert_eq!(batch_side, JoinSide::Right);
2185 }
2186 if join_type == JoinType::Left {
2187 debug_assert_eq!(batch_side, JoinSide::Left);
2188 }
2189
2190 let flipped_bitmap = not(&batch_bitmap)?;
2193
2194 let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
2196 .fields()
2197 .iter()
2198 .map(|field| new_null_array(field.data_type(), 1))
2199 .collect();
2200
2201 let nullable_left_schema = Arc::new(Schema::new(
2205 another_side_schema
2206 .fields()
2207 .iter()
2208 .map(|field| (**field).clone().with_nullable(true))
2209 .collect::<Vec<_>>(),
2210 ));
2211 let left_null_batch = if nullable_left_schema.fields.is_empty() {
2212 create_record_batch_with_empty_schema(nullable_left_schema, 0)?
2215 } else {
2216 RecordBatch::try_new(nullable_left_schema, left_null_columns)?
2217 };
2218
2219 debug_assert_ne!(batch_side, JoinSide::None);
2220 let opposite_side = batch_side.negate();
2221
2222 build_row_join_batch(
2223 output_schema,
2224 &left_null_batch,
2225 0,
2226 batch,
2227 Some(flipped_bitmap),
2228 col_indices,
2229 opposite_side,
2230 )
2231 }
2232 JoinType::RightSemi
2233 | JoinType::RightAnti
2234 | JoinType::LeftSemi
2235 | JoinType::LeftAnti => {
2236 if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
2237 debug_assert_eq!(batch_side, JoinSide::Right);
2238 }
2239 if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
2240 debug_assert_eq!(batch_side, JoinSide::Left);
2241 }
2242
2243 let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi)
2244 {
2245 batch_bitmap.clone()
2246 } else {
2247 not(&batch_bitmap)?
2248 };
2249
2250 if bitmap.true_count() == 0 {
2251 return Ok(None);
2252 }
2253
2254 let mut columns: Vec<Arc<dyn Array>> =
2255 Vec::with_capacity(output_schema.fields().len());
2256
2257 for column_index in col_indices {
2258 debug_assert!(column_index.side == batch_side);
2259
2260 let col = batch.column(column_index.index);
2261 let filtered_col = filter(col, &bitmap)?;
2262
2263 columns.push(filtered_col);
2264 }
2265
2266 Ok(Some(RecordBatch::try_new(
2267 Arc::clone(output_schema),
2268 columns,
2269 )?))
2270 }
2271 JoinType::RightMark | JoinType::LeftMark => {
2272 if join_type == JoinType::RightMark {
2273 debug_assert_eq!(batch_side, JoinSide::Right);
2274 }
2275 if join_type == JoinType::LeftMark {
2276 debug_assert_eq!(batch_side, JoinSide::Left);
2277 }
2278
2279 let mut columns: Vec<Arc<dyn Array>> =
2280 Vec::with_capacity(output_schema.fields().len());
2281
2282 let mut right_batch_bitmap_opt = Some(batch_bitmap);
2284
2285 for column_index in col_indices {
2286 if column_index.side == batch_side {
2287 let col = batch.column(column_index.index);
2288
2289 columns.push(Arc::clone(col));
2290 } else if column_index.side == JoinSide::None {
2291 let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
2292 match right_batch_bitmap {
2293 Some(right_batch_bitmap) => {
2294 columns.push(Arc::new(right_batch_bitmap))
2295 }
2296 None => unreachable!("Should only be one mark column"),
2297 }
2298 } else {
2299 return internal_err!(
2300 "Not possible to have this join side for RightMark join"
2301 );
2302 }
2303 }
2304
2305 Ok(Some(RecordBatch::try_new(
2306 Arc::clone(output_schema),
2307 columns,
2308 )?))
2309 }
2310 _ => internal_err!(
2311 "If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"
2312 ),
2313 }
2314}
2315
2316#[cfg(test)]
2317pub(crate) mod tests {
2318 use super::*;
2319 use crate::test::{TestMemoryExec, assert_join_metrics};
2320 use crate::{
2321 common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
2322 };
2323
2324 use arrow::compute::SortOptions;
2325 use arrow::datatypes::{DataType, Field};
2326 use datafusion_common::test_util::batches_to_sort_string;
2327 use datafusion_common::{ScalarValue, assert_contains};
2328 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
2329 use datafusion_expr::Operator;
2330 use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
2331 use datafusion_physical_expr::{Partitioning, PhysicalExpr};
2332 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
2333
2334 use insta::allow_duplicates;
2335 use insta::assert_snapshot;
2336 use rstest::rstest;
2337
2338 fn build_table(
2339 a: (&str, &Vec<i32>),
2340 b: (&str, &Vec<i32>),
2341 c: (&str, &Vec<i32>),
2342 batch_size: Option<usize>,
2343 sorted_column_names: Vec<&str>,
2344 ) -> Arc<dyn ExecutionPlan> {
2345 let batch = build_table_i32(a, b, c);
2346 let schema = batch.schema();
2347
2348 let batches = if let Some(batch_size) = batch_size {
2349 let num_batches = batch.num_rows().div_ceil(batch_size);
2350 (0..num_batches)
2351 .map(|i| {
2352 let start = i * batch_size;
2353 let remaining_rows = batch.num_rows() - start;
2354 batch.slice(start, batch_size.min(remaining_rows))
2355 })
2356 .collect::<Vec<_>>()
2357 } else {
2358 vec![batch]
2359 };
2360
2361 let mut sort_info = vec![];
2362 for name in sorted_column_names {
2363 let index = schema.index_of(name).unwrap();
2364 let sort_expr = PhysicalSortExpr::new(
2365 Arc::new(Column::new(name, index)),
2366 SortOptions::new(false, false),
2367 );
2368 sort_info.push(sort_expr);
2369 }
2370 let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
2371 if let Some(ordering) = LexOrdering::new(sort_info) {
2372 source = source.try_with_sort_information(vec![ordering]).unwrap();
2373 }
2374
2375 let source = Arc::new(source);
2376 Arc::new(TestMemoryExec::update_cache(&source))
2377 }
2378
2379 fn build_left_table() -> Arc<dyn ExecutionPlan> {
2380 build_table(
2381 ("a1", &vec![5, 9, 11]),
2382 ("b1", &vec![5, 8, 8]),
2383 ("c1", &vec![50, 90, 110]),
2384 None,
2385 Vec::new(),
2386 )
2387 }
2388
2389 fn build_right_table() -> Arc<dyn ExecutionPlan> {
2390 build_table(
2391 ("a2", &vec![12, 2, 10]),
2392 ("b2", &vec![10, 2, 10]),
2393 ("c2", &vec![40, 80, 100]),
2394 None,
2395 Vec::new(),
2396 )
2397 }
2398
2399 fn prepare_join_filter() -> JoinFilter {
2400 let column_indices = vec![
2401 ColumnIndex {
2402 index: 1,
2403 side: JoinSide::Left,
2404 },
2405 ColumnIndex {
2406 index: 1,
2407 side: JoinSide::Right,
2408 },
2409 ];
2410 let intermediate_schema = Schema::new(vec![
2411 Field::new("x", DataType::Int32, true),
2412 Field::new("x", DataType::Int32, true),
2413 ]);
2414 let left_filter = Arc::new(BinaryExpr::new(
2416 Arc::new(Column::new("x", 0)),
2417 Operator::NotEq,
2418 Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
2419 )) as Arc<dyn PhysicalExpr>;
2420 let right_filter = Arc::new(BinaryExpr::new(
2422 Arc::new(Column::new("x", 1)),
2423 Operator::NotEq,
2424 Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2425 )) as Arc<dyn PhysicalExpr>;
2426 let filter_expression =
2437 Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
2438 as Arc<dyn PhysicalExpr>;
2439
2440 JoinFilter::new(
2441 filter_expression,
2442 column_indices,
2443 Arc::new(intermediate_schema),
2444 )
2445 }
2446
2447 pub(crate) async fn multi_partitioned_join_collect(
2448 left: Arc<dyn ExecutionPlan>,
2449 right: Arc<dyn ExecutionPlan>,
2450 join_type: &JoinType,
2451 join_filter: Option<JoinFilter>,
2452 context: Arc<TaskContext>,
2453 ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
2454 let partition_count = 4;
2455
2456 let right = Arc::new(RepartitionExec::try_new(
2458 right,
2459 Partitioning::RoundRobinBatch(partition_count),
2460 )?) as Arc<dyn ExecutionPlan>;
2461
2462 let nested_loop_join =
2464 NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
2465 let columns = columns(&nested_loop_join.schema());
2466 let mut batches = vec![];
2467 for i in 0..partition_count {
2468 let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
2469 let more_batches = common::collect(stream).await?;
2470 batches.extend(
2471 more_batches
2472 .into_iter()
2473 .inspect(|b| {
2474 assert!(b.num_rows() <= context.session_config().batch_size())
2475 })
2476 .filter(|b| b.num_rows() > 0)
2477 .collect::<Vec<_>>(),
2478 );
2479 }
2480
2481 let metrics = nested_loop_join.metrics().unwrap();
2482
2483 Ok((columns, batches, metrics))
2484 }
2485
2486 fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
2487 let base = TaskContext::default();
2488 let cfg = base.session_config().clone().with_batch_size(batch_size);
2490 Arc::new(base.with_session_config(cfg))
2491 }
2492
2493 #[rstest]
2494 #[tokio::test]
2495 async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2496 let task_ctx = new_task_ctx(batch_size);
2497 dbg!(&batch_size);
2498 let left = build_left_table();
2499 let right = build_right_table();
2500 let filter = prepare_join_filter();
2501 let (columns, batches, metrics) = multi_partitioned_join_collect(
2502 left,
2503 right,
2504 &JoinType::Inner,
2505 Some(filter),
2506 task_ctx,
2507 )
2508 .await?;
2509
2510 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2511 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2512 +----+----+----+----+----+----+
2513 | a1 | b1 | c1 | a2 | b2 | c2 |
2514 +----+----+----+----+----+----+
2515 | 5 | 5 | 50 | 2 | 2 | 80 |
2516 +----+----+----+----+----+----+
2517 "));
2518
2519 assert_join_metrics!(metrics, 1);
2520
2521 Ok(())
2522 }
2523
2524 #[rstest]
2525 #[tokio::test]
2526 async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2527 let task_ctx = new_task_ctx(batch_size);
2528 let left = build_left_table();
2529 let right = build_right_table();
2530
2531 let filter = prepare_join_filter();
2532 let (columns, batches, metrics) = multi_partitioned_join_collect(
2533 left,
2534 right,
2535 &JoinType::Left,
2536 Some(filter),
2537 task_ctx,
2538 )
2539 .await?;
2540 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2541 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2542 +----+----+-----+----+----+----+
2543 | a1 | b1 | c1 | a2 | b2 | c2 |
2544 +----+----+-----+----+----+----+
2545 | 11 | 8 | 110 | | | |
2546 | 5 | 5 | 50 | 2 | 2 | 80 |
2547 | 9 | 8 | 90 | | | |
2548 +----+----+-----+----+----+----+
2549 "));
2550
2551 assert_join_metrics!(metrics, 3);
2552
2553 Ok(())
2554 }
2555
2556 #[rstest]
2557 #[tokio::test]
2558 async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2559 let task_ctx = new_task_ctx(batch_size);
2560 let left = build_left_table();
2561 let right = build_right_table();
2562
2563 let filter = prepare_join_filter();
2564 let (columns, batches, metrics) = multi_partitioned_join_collect(
2565 left,
2566 right,
2567 &JoinType::Right,
2568 Some(filter),
2569 task_ctx,
2570 )
2571 .await?;
2572 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2573 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2574 +----+----+----+----+----+-----+
2575 | a1 | b1 | c1 | a2 | b2 | c2 |
2576 +----+----+----+----+----+-----+
2577 | | | | 10 | 10 | 100 |
2578 | | | | 12 | 10 | 40 |
2579 | 5 | 5 | 50 | 2 | 2 | 80 |
2580 +----+----+----+----+----+-----+
2581 "));
2582
2583 assert_join_metrics!(metrics, 3);
2584
2585 Ok(())
2586 }
2587
2588 #[rstest]
2589 #[tokio::test]
2590 async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2591 let task_ctx = new_task_ctx(batch_size);
2592 let left = build_left_table();
2593 let right = build_right_table();
2594
2595 let filter = prepare_join_filter();
2596 let (columns, batches, metrics) = multi_partitioned_join_collect(
2597 left,
2598 right,
2599 &JoinType::Full,
2600 Some(filter),
2601 task_ctx,
2602 )
2603 .await?;
2604 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2605 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2606 +----+----+-----+----+----+-----+
2607 | a1 | b1 | c1 | a2 | b2 | c2 |
2608 +----+----+-----+----+----+-----+
2609 | | | | 10 | 10 | 100 |
2610 | | | | 12 | 10 | 40 |
2611 | 11 | 8 | 110 | | | |
2612 | 5 | 5 | 50 | 2 | 2 | 80 |
2613 | 9 | 8 | 90 | | | |
2614 +----+----+-----+----+----+-----+
2615 "));
2616
2617 assert_join_metrics!(metrics, 5);
2618
2619 Ok(())
2620 }
2621
2622 #[rstest]
2623 #[tokio::test]
2624 async fn join_left_semi_with_filter(
2625 #[values(1, 2, 16)] batch_size: usize,
2626 ) -> Result<()> {
2627 let task_ctx = new_task_ctx(batch_size);
2628 let left = build_left_table();
2629 let right = build_right_table();
2630
2631 let filter = prepare_join_filter();
2632 let (columns, batches, metrics) = multi_partitioned_join_collect(
2633 left,
2634 right,
2635 &JoinType::LeftSemi,
2636 Some(filter),
2637 task_ctx,
2638 )
2639 .await?;
2640 assert_eq!(columns, vec!["a1", "b1", "c1"]);
2641 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2642 +----+----+----+
2643 | a1 | b1 | c1 |
2644 +----+----+----+
2645 | 5 | 5 | 50 |
2646 +----+----+----+
2647 "));
2648
2649 assert_join_metrics!(metrics, 1);
2650
2651 Ok(())
2652 }
2653
2654 #[rstest]
2655 #[tokio::test]
2656 async fn join_left_anti_with_filter(
2657 #[values(1, 2, 16)] batch_size: usize,
2658 ) -> Result<()> {
2659 let task_ctx = new_task_ctx(batch_size);
2660 let left = build_left_table();
2661 let right = build_right_table();
2662
2663 let filter = prepare_join_filter();
2664 let (columns, batches, metrics) = multi_partitioned_join_collect(
2665 left,
2666 right,
2667 &JoinType::LeftAnti,
2668 Some(filter),
2669 task_ctx,
2670 )
2671 .await?;
2672 assert_eq!(columns, vec!["a1", "b1", "c1"]);
2673 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2674 +----+----+-----+
2675 | a1 | b1 | c1 |
2676 +----+----+-----+
2677 | 11 | 8 | 110 |
2678 | 9 | 8 | 90 |
2679 +----+----+-----+
2680 "));
2681
2682 assert_join_metrics!(metrics, 2);
2683
2684 Ok(())
2685 }
2686
2687 #[tokio::test]
2688 async fn join_has_correct_stats() -> Result<()> {
2689 let left = build_left_table();
2690 let right = build_right_table();
2691 let nested_loop_join = NestedLoopJoinExec::try_new(
2692 left,
2693 right,
2694 None,
2695 &JoinType::Left,
2696 Some(vec![1, 2]),
2697 )?;
2698 let stats = nested_loop_join.partition_statistics(None)?;
2699 assert_eq!(
2700 nested_loop_join.schema().fields().len(),
2701 stats.column_statistics.len(),
2702 );
2703 assert_eq!(2, stats.column_statistics.len());
2704 Ok(())
2705 }
2706
2707 #[rstest]
2708 #[tokio::test]
2709 async fn join_right_semi_with_filter(
2710 #[values(1, 2, 16)] batch_size: usize,
2711 ) -> Result<()> {
2712 let task_ctx = new_task_ctx(batch_size);
2713 let left = build_left_table();
2714 let right = build_right_table();
2715
2716 let filter = prepare_join_filter();
2717 let (columns, batches, metrics) = multi_partitioned_join_collect(
2718 left,
2719 right,
2720 &JoinType::RightSemi,
2721 Some(filter),
2722 task_ctx,
2723 )
2724 .await?;
2725 assert_eq!(columns, vec!["a2", "b2", "c2"]);
2726 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2727 +----+----+----+
2728 | a2 | b2 | c2 |
2729 +----+----+----+
2730 | 2 | 2 | 80 |
2731 +----+----+----+
2732 "));
2733
2734 assert_join_metrics!(metrics, 1);
2735
2736 Ok(())
2737 }
2738
2739 #[rstest]
2740 #[tokio::test]
2741 async fn join_right_anti_with_filter(
2742 #[values(1, 2, 16)] batch_size: usize,
2743 ) -> Result<()> {
2744 let task_ctx = new_task_ctx(batch_size);
2745 let left = build_left_table();
2746 let right = build_right_table();
2747
2748 let filter = prepare_join_filter();
2749 let (columns, batches, metrics) = multi_partitioned_join_collect(
2750 left,
2751 right,
2752 &JoinType::RightAnti,
2753 Some(filter),
2754 task_ctx,
2755 )
2756 .await?;
2757 assert_eq!(columns, vec!["a2", "b2", "c2"]);
2758 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2759 +----+----+-----+
2760 | a2 | b2 | c2 |
2761 +----+----+-----+
2762 | 10 | 10 | 100 |
2763 | 12 | 10 | 40 |
2764 +----+----+-----+
2765 "));
2766
2767 assert_join_metrics!(metrics, 2);
2768
2769 Ok(())
2770 }
2771
2772 #[rstest]
2773 #[tokio::test]
2774 async fn join_left_mark_with_filter(
2775 #[values(1, 2, 16)] batch_size: usize,
2776 ) -> Result<()> {
2777 let task_ctx = new_task_ctx(batch_size);
2778 let left = build_left_table();
2779 let right = build_right_table();
2780
2781 let filter = prepare_join_filter();
2782 let (columns, batches, metrics) = multi_partitioned_join_collect(
2783 left,
2784 right,
2785 &JoinType::LeftMark,
2786 Some(filter),
2787 task_ctx,
2788 )
2789 .await?;
2790 assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
2791 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2792 +----+----+-----+-------+
2793 | a1 | b1 | c1 | mark |
2794 +----+----+-----+-------+
2795 | 11 | 8 | 110 | false |
2796 | 5 | 5 | 50 | true |
2797 | 9 | 8 | 90 | false |
2798 +----+----+-----+-------+
2799 "));
2800
2801 assert_join_metrics!(metrics, 3);
2802
2803 Ok(())
2804 }
2805
2806 #[rstest]
2807 #[tokio::test]
2808 async fn join_right_mark_with_filter(
2809 #[values(1, 2, 16)] batch_size: usize,
2810 ) -> Result<()> {
2811 let task_ctx = new_task_ctx(batch_size);
2812 let left = build_left_table();
2813 let right = build_right_table();
2814
2815 let filter = prepare_join_filter();
2816 let (columns, batches, metrics) = multi_partitioned_join_collect(
2817 left,
2818 right,
2819 &JoinType::RightMark,
2820 Some(filter),
2821 task_ctx,
2822 )
2823 .await?;
2824 assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
2825
2826 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r"
2827 +----+----+-----+-------+
2828 | a2 | b2 | c2 | mark |
2829 +----+----+-----+-------+
2830 | 10 | 10 | 100 | false |
2831 | 12 | 10 | 40 | false |
2832 | 2 | 2 | 80 | true |
2833 +----+----+-----+-------+
2834 "));
2835
2836 assert_join_metrics!(metrics, 3);
2837
2838 Ok(())
2839 }
2840
2841 #[tokio::test]
2842 async fn test_overallocation() -> Result<()> {
2843 let left = build_table(
2844 ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2845 ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2846 ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2847 None,
2848 Vec::new(),
2849 );
2850 let right = build_table(
2851 ("a2", &vec![10, 11]),
2852 ("b2", &vec![12, 13]),
2853 ("c2", &vec![14, 15]),
2854 None,
2855 Vec::new(),
2856 );
2857 let filter = prepare_join_filter();
2858
2859 let join_types = vec![
2860 JoinType::Inner,
2861 JoinType::Left,
2862 JoinType::Right,
2863 JoinType::Full,
2864 JoinType::LeftSemi,
2865 JoinType::LeftAnti,
2866 JoinType::LeftMark,
2867 JoinType::RightSemi,
2868 JoinType::RightAnti,
2869 JoinType::RightMark,
2870 ];
2871
2872 for join_type in join_types {
2873 let runtime = RuntimeEnvBuilder::new()
2874 .with_memory_limit(100, 1.0)
2875 .build_arc()?;
2876 let task_ctx = TaskContext::default().with_runtime(runtime);
2877 let task_ctx = Arc::new(task_ctx);
2878
2879 let err = multi_partitioned_join_collect(
2880 Arc::clone(&left),
2881 Arc::clone(&right),
2882 &join_type,
2883 Some(filter.clone()),
2884 task_ctx,
2885 )
2886 .await
2887 .unwrap_err();
2888
2889 assert_contains!(
2890 err.to_string(),
2891 "Resources exhausted: Additional allocation failed for NestedLoopJoinLoad[0] with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]"
2892 );
2893 }
2894
2895 Ok(())
2896 }
2897
2898 fn columns(schema: &Schema) -> Vec<String> {
2900 schema.fields().iter().map(|f| f.name().clone()).collect()
2901 }
2902}