1use std::any::Any;
21use std::fmt::Formatter;
22use std::ops::{BitOr, ControlFlow};
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::task::Poll;
26
27use super::utils::{
28 asymmetric_join_output_partitioning, need_produce_result_in_final,
29 reorder_output_after_swap, swap_join_projection,
30};
31use crate::common::can_project;
32use crate::execution_plan::{boundedness_from_children, EmissionType};
33use crate::joins::utils::{
34 build_join_schema, check_join_is_valid, estimate_join_statistics,
35 need_produce_right_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
36 OnceAsync, OnceFut,
37};
38use crate::joins::SharedBitmapBuilder;
39use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricsSet};
40use crate::projection::{
41 try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
42 ProjectionExec,
43};
44use crate::{
45 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
46 PlanProperties, RecordBatchStream, SendableRecordBatchStream,
47};
48
49use arrow::array::{
50 new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions,
51};
52use arrow::buffer::BooleanBuffer;
53use arrow::compute::{concat_batches, filter, filter_record_batch, not, BatchCoalescer};
54use arrow::datatypes::{Schema, SchemaRef};
55use arrow::record_batch::RecordBatch;
56use datafusion_common::cast::as_boolean_array;
57use datafusion_common::{
58 arrow_err, internal_datafusion_err, internal_err, project_schema,
59 unwrap_or_internal_err, DataFusionError, JoinSide, Result, ScalarValue, Statistics,
60};
61use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
62use datafusion_execution::TaskContext;
63use datafusion_expr::JoinType;
64use datafusion_physical_expr::equivalence::{
65 join_equivalence_properties, ProjectionMapping,
66};
67
68use futures::{Stream, StreamExt, TryStreamExt};
69use log::debug;
70use parking_lot::Mutex;
71
/// Physical plan node for a nested loop join.
///
/// The left child is executed as a single partition and collected in full
/// (see `execute`); each partition of the right child is then streamed
/// against that collected data, with the optional [`JoinFilter`] deciding
/// which (left row, right row) pairs match.
#[allow(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
pub struct NestedLoopJoinExec {
    /// Left input; must expose exactly one output partition at execution time.
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right input; executed once per output partition.
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Optional predicate evaluated on candidate row pairs. When absent, every
    /// pair is treated as a match.
    pub(crate) filter: Option<JoinFilter>,
    /// Join semantics (inner, left, semi, anti, mark, ...).
    pub(crate) join_type: JoinType,
    /// Schema produced by `build_join_schema`, i.e. before `projection` is applied.
    join_schema: SchemaRef,
    /// Lazily computed, shared left-side data. `execute` initializes it through
    /// `try_once`, so the left input is collected at most once per plan node.
    build_side_data: OnceAsync<JoinLeftData>,
    /// For every column of `join_schema`, the input side and index it comes from.
    column_indices: Vec<ColumnIndex>,
    /// Optional embedded projection, expressed as indices into `join_schema`.
    projection: Option<Vec<usize>>,

    /// Execution metrics registry for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once in `try_new` and returned by `properties`.
    cache: PlanProperties,
}
195
196impl NestedLoopJoinExec {
197 pub fn try_new(
199 left: Arc<dyn ExecutionPlan>,
200 right: Arc<dyn ExecutionPlan>,
201 filter: Option<JoinFilter>,
202 join_type: &JoinType,
203 projection: Option<Vec<usize>>,
204 ) -> Result<Self> {
205 let left_schema = left.schema();
206 let right_schema = right.schema();
207 check_join_is_valid(&left_schema, &right_schema, &[])?;
208 let (join_schema, column_indices) =
209 build_join_schema(&left_schema, &right_schema, join_type);
210 let join_schema = Arc::new(join_schema);
211 let cache = Self::compute_properties(
212 &left,
213 &right,
214 Arc::clone(&join_schema),
215 *join_type,
216 projection.as_ref(),
217 )?;
218
219 Ok(NestedLoopJoinExec {
220 left,
221 right,
222 filter,
223 join_type: *join_type,
224 join_schema,
225 build_side_data: Default::default(),
226 column_indices,
227 projection,
228 metrics: Default::default(),
229 cache,
230 })
231 }
232
233 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
235 &self.left
236 }
237
238 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
240 &self.right
241 }
242
243 pub fn filter(&self) -> Option<&JoinFilter> {
245 self.filter.as_ref()
246 }
247
248 pub fn join_type(&self) -> &JoinType {
250 &self.join_type
251 }
252
253 pub fn projection(&self) -> Option<&Vec<usize>> {
254 self.projection.as_ref()
255 }
256
257 fn compute_properties(
259 left: &Arc<dyn ExecutionPlan>,
260 right: &Arc<dyn ExecutionPlan>,
261 schema: SchemaRef,
262 join_type: JoinType,
263 projection: Option<&Vec<usize>>,
264 ) -> Result<PlanProperties> {
265 let mut eq_properties = join_equivalence_properties(
267 left.equivalence_properties().clone(),
268 right.equivalence_properties().clone(),
269 &join_type,
270 Arc::clone(&schema),
271 &Self::maintains_input_order(join_type),
272 None,
273 &[],
275 )?;
276
277 let mut output_partitioning =
278 asymmetric_join_output_partitioning(left, right, &join_type)?;
279
280 let emission_type = if left.boundedness().is_unbounded() {
281 EmissionType::Final
282 } else if right.pipeline_behavior() == EmissionType::Incremental {
283 match join_type {
284 JoinType::Inner
287 | JoinType::LeftSemi
288 | JoinType::RightSemi
289 | JoinType::Right
290 | JoinType::RightAnti
291 | JoinType::RightMark => EmissionType::Incremental,
292 JoinType::Left
295 | JoinType::LeftAnti
296 | JoinType::LeftMark
297 | JoinType::Full => EmissionType::Both,
298 }
299 } else {
300 right.pipeline_behavior()
301 };
302
303 if let Some(projection) = projection {
304 let projection_mapping =
306 ProjectionMapping::from_indices(projection, &schema)?;
307 let out_schema = project_schema(&schema, Some(projection))?;
308 output_partitioning =
309 output_partitioning.project(&projection_mapping, &eq_properties);
310 eq_properties = eq_properties.project(&projection_mapping, out_schema);
311 }
312
313 Ok(PlanProperties::new(
314 eq_properties,
315 output_partitioning,
316 emission_type,
317 boundedness_from_children([left, right]),
318 ))
319 }
320
321 fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
323 vec![false, false]
324 }
325
326 pub fn contains_projection(&self) -> bool {
327 self.projection.is_some()
328 }
329
330 pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
331 can_project(&self.schema(), projection.as_ref())?;
333 let projection = match projection {
334 Some(projection) => match &self.projection {
335 Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
336 None => Some(projection),
337 },
338 None => None,
339 };
340 Self::try_new(
341 Arc::clone(&self.left),
342 Arc::clone(&self.right),
343 self.filter.clone(),
344 &self.join_type,
345 projection,
346 )
347 }
348
349 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
358 let left = self.left();
359 let right = self.right();
360 let new_join = NestedLoopJoinExec::try_new(
361 Arc::clone(right),
362 Arc::clone(left),
363 self.filter().map(JoinFilter::swap),
364 &self.join_type().swap(),
365 swap_join_projection(
366 left.schema().fields().len(),
367 right.schema().fields().len(),
368 self.projection.as_ref(),
369 self.join_type(),
370 ),
371 )?;
372
373 let plan: Arc<dyn ExecutionPlan> = if matches!(
376 self.join_type(),
377 JoinType::LeftSemi
378 | JoinType::RightSemi
379 | JoinType::LeftAnti
380 | JoinType::RightAnti
381 ) || self.projection.is_some()
382 {
383 Arc::new(new_join)
384 } else {
385 reorder_output_after_swap(
386 Arc::new(new_join),
387 &self.left().schema(),
388 &self.right().schema(),
389 )?
390 };
391
392 Ok(plan)
393 }
394}
395
396impl DisplayAs for NestedLoopJoinExec {
397 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
398 match t {
399 DisplayFormatType::Default | DisplayFormatType::Verbose => {
400 let display_filter = self.filter.as_ref().map_or_else(
401 || "".to_string(),
402 |f| format!(", filter={}", f.expression()),
403 );
404 let display_projections = if self.contains_projection() {
405 format!(
406 ", projection=[{}]",
407 self.projection
408 .as_ref()
409 .unwrap()
410 .iter()
411 .map(|index| format!(
412 "{}@{}",
413 self.join_schema.fields().get(*index).unwrap().name(),
414 index
415 ))
416 .collect::<Vec<_>>()
417 .join(", ")
418 )
419 } else {
420 "".to_string()
421 };
422 write!(
423 f,
424 "NestedLoopJoinExec: join_type={:?}{}{}",
425 self.join_type, display_filter, display_projections
426 )
427 }
428 DisplayFormatType::TreeRender => {
429 if *self.join_type() != JoinType::Inner {
430 writeln!(f, "join_type={:?}", self.join_type)
431 } else {
432 Ok(())
433 }
434 }
435 }
436 }
437}
438
439impl ExecutionPlan for NestedLoopJoinExec {
440 fn name(&self) -> &'static str {
441 "NestedLoopJoinExec"
442 }
443
444 fn as_any(&self) -> &dyn Any {
445 self
446 }
447
448 fn properties(&self) -> &PlanProperties {
449 &self.cache
450 }
451
452 fn required_input_distribution(&self) -> Vec<Distribution> {
453 vec![
454 Distribution::SinglePartition,
455 Distribution::UnspecifiedDistribution,
456 ]
457 }
458
459 fn maintains_input_order(&self) -> Vec<bool> {
460 Self::maintains_input_order(self.join_type)
461 }
462
463 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
464 vec![&self.left, &self.right]
465 }
466
467 fn with_new_children(
468 self: Arc<Self>,
469 children: Vec<Arc<dyn ExecutionPlan>>,
470 ) -> Result<Arc<dyn ExecutionPlan>> {
471 Ok(Arc::new(NestedLoopJoinExec::try_new(
472 Arc::clone(&children[0]),
473 Arc::clone(&children[1]),
474 self.filter.clone(),
475 &self.join_type,
476 self.projection.clone(),
477 )?))
478 }
479
480 fn execute(
481 &self,
482 partition: usize,
483 context: Arc<TaskContext>,
484 ) -> Result<SendableRecordBatchStream> {
485 if self.left.output_partitioning().partition_count() != 1 {
486 return internal_err!(
487 "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
488 consider using CoalescePartitionsExec or the EnforceDistribution rule"
489 );
490 }
491
492 let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
493
494 let load_reservation =
496 MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
497 .register(context.memory_pool());
498
499 let build_side_data = self.build_side_data.try_once(|| {
500 let stream = self.left.execute(0, Arc::clone(&context))?;
501
502 Ok(collect_left_input(
503 stream,
504 join_metrics.clone(),
505 load_reservation,
506 need_produce_result_in_final(self.join_type),
507 self.right().output_partitioning().partition_count(),
508 ))
509 })?;
510
511 let batch_size = context.session_config().batch_size();
512
513 let probe_side_data = self.right.execute(partition, context)?;
514
515 let column_indices_after_projection = match &self.projection {
517 Some(projection) => projection
518 .iter()
519 .map(|i| self.column_indices[*i].clone())
520 .collect(),
521 None => self.column_indices.clone(),
522 };
523
524 Ok(Box::pin(NestedLoopJoinStream::new(
525 self.schema(),
526 self.filter.clone(),
527 self.join_type,
528 probe_side_data,
529 build_side_data,
530 column_indices_after_projection,
531 join_metrics,
532 batch_size,
533 )))
534 }
535
536 fn metrics(&self) -> Option<MetricsSet> {
537 Some(self.metrics.clone_inner())
538 }
539
540 fn statistics(&self) -> Result<Statistics> {
541 self.partition_statistics(None)
542 }
543
544 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
545 if partition.is_some() {
546 return Ok(Statistics::new_unknown(&self.schema()));
547 }
548 estimate_join_statistics(
549 self.left.partition_statistics(None)?,
550 self.right.partition_statistics(None)?,
551 vec![],
552 &self.join_type,
553 &self.join_schema,
554 )
555 }
556
557 fn try_swapping_with_projection(
561 &self,
562 projection: &ProjectionExec,
563 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
564 if self.contains_projection() {
566 return Ok(None);
567 }
568
569 if let Some(JoinData {
570 projected_left_child,
571 projected_right_child,
572 join_filter,
573 ..
574 }) = try_pushdown_through_join(
575 projection,
576 self.left(),
577 self.right(),
578 &[],
579 self.schema(),
580 self.filter(),
581 )? {
582 Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
583 Arc::new(projected_left_child),
584 Arc::new(projected_right_child),
585 join_filter,
586 self.join_type(),
587 None,
589 )?)))
590 } else {
591 try_embed_projection(projection, self)
592 }
593 }
594}
595
596impl EmbeddedProjection for NestedLoopJoinExec {
597 fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
598 self.with_projection(projection)
599 }
600}
601
/// Fully collected left side of the join, shared by all probe streams.
pub(crate) struct JoinLeftData {
    /// All left input batches concatenated into one.
    batch: RecordBatch,
    /// One bit per left row, set once the row has matched some right row.
    /// Created with zero length when the join type does not need it.
    bitmap: SharedBitmapBuilder,
    /// Number of probe streams that have not yet reported completion.
    probe_threads_counter: AtomicUsize,
    /// Memory reservation for the data above. It is never read; it is stored
    /// so that it lives exactly as long as this struct.
    #[expect(dead_code)]
    reservation: MemoryReservation,
}
616
617impl JoinLeftData {
618 pub(crate) fn new(
619 batch: RecordBatch,
620 bitmap: SharedBitmapBuilder,
621 probe_threads_counter: AtomicUsize,
622 reservation: MemoryReservation,
623 ) -> Self {
624 Self {
625 batch,
626 bitmap,
627 probe_threads_counter,
628 reservation,
629 }
630 }
631
632 pub(crate) fn batch(&self) -> &RecordBatch {
633 &self.batch
634 }
635
636 pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
637 &self.bitmap
638 }
639
640 pub(crate) fn report_probe_completed(&self) -> bool {
643 self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
644 }
645}
646
647async fn collect_left_input(
649 stream: SendableRecordBatchStream,
650 join_metrics: BuildProbeJoinMetrics,
651 reservation: MemoryReservation,
652 with_visited_left_side: bool,
653 probe_threads_count: usize,
654) -> Result<JoinLeftData> {
655 let schema = stream.schema();
656
657 let (batches, metrics, mut reservation) = stream
659 .try_fold(
660 (Vec::new(), join_metrics, reservation),
661 |(mut batches, metrics, mut reservation), batch| async {
662 let batch_size = batch.get_array_memory_size();
663 reservation.try_grow(batch_size)?;
665 metrics.build_mem_used.add(batch_size);
667 metrics.build_input_batches.add(1);
668 metrics.build_input_rows.add(batch.num_rows());
669 batches.push(batch);
671 Ok((batches, metrics, reservation))
672 },
673 )
674 .await?;
675
676 let merged_batch = concat_batches(&schema, &batches)?;
677
678 let visited_left_side = if with_visited_left_side {
680 let n_rows = merged_batch.num_rows();
681 let buffer_size = n_rows.div_ceil(8);
682 reservation.try_grow(buffer_size)?;
683 metrics.build_mem_used.add(buffer_size);
684
685 let mut buffer = BooleanBufferBuilder::new(n_rows);
686 buffer.append_n(n_rows, false);
687 buffer
688 } else {
689 BooleanBufferBuilder::new(0)
690 };
691
692 Ok(JoinLeftData::new(
693 merged_batch,
694 Mutex::new(visited_left_side),
695 AtomicUsize::new(probe_threads_count),
696 reservation,
697 ))
698}
699
/// States of the [`NestedLoopJoinStream`] state machine, in the order they are
/// normally visited.
#[derive(Debug, Clone, Copy)]
enum NLJState {
    /// Waiting for the shared left-side data to become available.
    BufferingLeft,
    /// Polling the right input for its next batch.
    FetchingRight,
    /// Joining the current right batch against the left rows, one left row per step.
    ProbeRight,
    /// Emitting the result derived from the current right batch's matched
    /// bitmap (only entered when right-side matches are tracked).
    EmitRightUnmatched,
    /// Right input is exhausted; emitting the result derived from the left
    /// visited bitmap, in chunks.
    EmitLeftUnmatched,
    /// Flushing remaining buffered output, then ending the stream.
    Done,
}
/// Stream producing the output of one partition of a nested loop join.
pub(crate) struct NestedLoopJoinStream {
    /// Schema of the emitted batches.
    pub(crate) output_schema: Arc<Schema>,
    /// Optional join predicate; when `None`, every row pair matches.
    pub(crate) join_filter: Option<JoinFilter>,
    /// Join semantics.
    pub(crate) join_type: JoinType,
    /// Right (probe) input stream.
    pub(crate) right_data: SendableRecordBatchStream,
    /// Future resolving to the shared, collected left side.
    pub(crate) left_data: OnceFut<JoinLeftData>,
    /// Source (side and index) of every output column.
    pub(crate) column_indices: Vec<ColumnIndex>,
    /// Build/probe metrics for this partition.
    pub(crate) join_metrics: BuildProbeJoinMetrics,

    /// Target output batch size; also the chunk size used when walking the
    /// left side in `EmitLeftUnmatched`.
    batch_size: usize,

    /// Whether matches must be tracked per right row, as decided by
    /// `need_produce_right_in_final` for the join type.
    should_track_unmatched_right: bool,

    /// Current state of the state machine.
    state: NLJState,
    /// Accumulates produced batches and releases them once complete.
    output_buffer: Box<BatchCoalescer>,
    /// Set once the single empty batch for an empty result has been emitted.
    handled_empty_output: bool,

    /// Left-side data, available once `BufferingLeft` has finished.
    buffered_left_data: Option<Arc<JoinLeftData>>,
    /// Index of the next left row to join against the current right batch.
    left_probe_idx: usize,
    /// Index of the next left row to consider in `EmitLeftUnmatched`.
    left_emit_idx: usize,
    /// Set to `true` once the left data has been obtained.
    left_exhausted: bool,
    /// Initialized to `true` and not read by the code in this struct's impl.
    #[allow(dead_code)]
    left_buffered_in_one_pass: bool,

    /// Right batch currently being probed.
    current_right_batch: Option<RecordBatch>,
    /// For the current right batch, one bit per row recording whether the row
    /// has matched any left row so far. Only populated when
    /// `should_track_unmatched_right` is set.
    current_right_batch_matched: Option<BooleanArray>,
}
789
impl Stream for NestedLoopJoinStream {
    type Item = Result<RecordBatch>;

    /// Drives the join state machine until a batch is ready, an input is
    /// pending, an error occurs, or the stream ends.
    ///
    /// Each state handler returns `ControlFlow::Continue` to re-enter the loop
    /// (typically after changing `self.state`) or `ControlFlow::Break` with
    /// the poll result to hand back to the caller.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.state {
                // Wait for the collected left side.
                NLJState::BufferingLeft => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    // The metric is cloned so the timer guard does not keep
                    // `self` borrowed while the handler takes `&mut self`.
                    let build_metric = self.join_metrics.build_time.clone();
                    let _build_timer = build_metric.timer();

                    match self.handle_buffering_left(cx) {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => return poll,
                    }
                }

                // Pull the next right batch.
                NLJState::FetchingRight => {
                    debug!("[NLJState] Entering: {:?}", self.state);
                    let join_metric = self.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();

                    match self.handle_fetching_right(cx) {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => return poll,
                    }
                }

                // Join one left row with the current right batch per iteration.
                NLJState::ProbeRight => {
                    debug!("[NLJState] Entering: {:?}", self.state);

                    let join_metric = self.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();

                    match self.handle_probe_right() {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => {
                            // Output-producing states report through the
                            // baseline metrics.
                            return self.join_metrics.baseline.record_poll(poll)
                        }
                    }
                }

                // Emit the bitmap-derived result for the current right batch.
                NLJState::EmitRightUnmatched => {
                    debug!("[NLJState] Entering: {:?}", self.state);

                    let join_metric = self.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();

                    match self.handle_emit_right_unmatched() {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => {
                            return self.join_metrics.baseline.record_poll(poll)
                        }
                    }
                }

                // Right input exhausted: emit the bitmap-derived result for
                // the left side.
                NLJState::EmitLeftUnmatched => {
                    debug!("[NLJState] Entering: {:?}", self.state);

                    let join_metric = self.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();

                    match self.handle_emit_left_unmatched() {
                        ControlFlow::Continue(()) => continue,
                        ControlFlow::Break(poll) => {
                            return self.join_metrics.baseline.record_poll(poll)
                        }
                    }
                }

                // Flush what is left and finish. This state always returns.
                NLJState::Done => {
                    debug!("[NLJState] Entering: {:?}", self.state);

                    let join_metric = self.join_metrics.join_time.clone();
                    let _join_timer = join_metric.timer();
                    let poll = self.handle_done();
                    return self.join_metrics.baseline.record_poll(poll);
                }
            }
        }
    }
}
980
981impl RecordBatchStream for NestedLoopJoinStream {
982 fn schema(&self) -> SchemaRef {
983 Arc::clone(&self.output_schema)
984 }
985}
986
987impl NestedLoopJoinStream {
988 #[allow(clippy::too_many_arguments)]
989 pub(crate) fn new(
990 schema: Arc<Schema>,
991 filter: Option<JoinFilter>,
992 join_type: JoinType,
993 right_data: SendableRecordBatchStream,
994 left_data: OnceFut<JoinLeftData>,
995 column_indices: Vec<ColumnIndex>,
996 join_metrics: BuildProbeJoinMetrics,
997 batch_size: usize,
998 ) -> Self {
999 Self {
1000 output_schema: Arc::clone(&schema),
1001 join_filter: filter,
1002 join_type,
1003 right_data,
1004 column_indices,
1005 left_data,
1006 join_metrics,
1007 buffered_left_data: None,
1008 output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
1009 batch_size,
1010 current_right_batch: None,
1011 current_right_batch_matched: None,
1012 state: NLJState::BufferingLeft,
1013 left_probe_idx: 0,
1014 left_emit_idx: 0,
1015 left_exhausted: false,
1016 left_buffered_in_one_pass: true,
1017 handled_empty_output: false,
1018 should_track_unmatched_right: need_produce_right_in_final(join_type),
1019 }
1020 }
1021
1022 fn handle_buffering_left(
1026 &mut self,
1027 cx: &mut std::task::Context<'_>,
1028 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1029 match self.left_data.get_shared(cx) {
1030 Poll::Ready(Ok(left_data)) => {
1031 self.buffered_left_data = Some(left_data);
1032 self.left_exhausted = true;
1034 self.state = NLJState::FetchingRight;
1035 ControlFlow::Continue(())
1037 }
1038 Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1039 Poll::Pending => ControlFlow::Break(Poll::Pending),
1040 }
1041 }
1042
1043 fn handle_fetching_right(
1045 &mut self,
1046 cx: &mut std::task::Context<'_>,
1047 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1048 match self.right_data.poll_next_unpin(cx) {
1049 Poll::Ready(result) => match result {
1050 Some(Ok(right_batch)) => {
1051 let right_batch_size = right_batch.num_rows();
1053 self.join_metrics.input_rows.add(right_batch_size);
1054 self.join_metrics.input_batches.add(1);
1055
1056 if right_batch_size == 0 {
1058 return ControlFlow::Continue(());
1059 }
1060
1061 self.current_right_batch = Some(right_batch);
1062
1063 if self.should_track_unmatched_right {
1065 let zeroed_buf = BooleanBuffer::new_unset(right_batch_size);
1066 self.current_right_batch_matched =
1067 Some(BooleanArray::new(zeroed_buf, None));
1068 }
1069
1070 self.left_probe_idx = 0;
1071 self.state = NLJState::ProbeRight;
1072 ControlFlow::Continue(())
1073 }
1074 Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1075 None => {
1076 self.state = NLJState::EmitLeftUnmatched;
1078 ControlFlow::Continue(())
1079 }
1080 },
1081 Poll::Pending => ControlFlow::Break(Poll::Pending),
1082 }
1083 }
1084
1085 fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1087 if let Some(poll) = self.maybe_flush_ready_batch() {
1089 return ControlFlow::Break(poll);
1090 }
1091
1092 match self.process_probe_batch() {
1094 Ok(true) => ControlFlow::Continue(()),
1098 Ok(false) => {
1102 self.left_probe_idx = 0;
1104 if self.should_track_unmatched_right {
1105 debug_assert!(
1106 self.current_right_batch_matched.is_some(),
1107 "If it's required to track matched rows in the right input, the right bitmap must be present"
1108 );
1109 self.state = NLJState::EmitRightUnmatched;
1110 } else {
1111 self.current_right_batch = None;
1112 self.state = NLJState::FetchingRight;
1113 }
1114 ControlFlow::Continue(())
1115 }
1116 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1117 }
1118 }
1119
1120 fn handle_emit_right_unmatched(
1122 &mut self,
1123 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1124 if let Some(poll) = self.maybe_flush_ready_batch() {
1126 return ControlFlow::Break(poll);
1127 }
1128
1129 debug_assert!(
1130 self.current_right_batch_matched.is_some()
1131 && self.current_right_batch.is_some(),
1132 "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
1133 );
1134
1135 match self.process_right_unmatched() {
1137 Ok(Some(batch)) => {
1138 match self.output_buffer.push_batch(batch) {
1139 Ok(()) => {
1140 debug_assert!(self.current_right_batch.is_none());
1143 self.state = NLJState::FetchingRight;
1144 ControlFlow::Continue(())
1145 }
1146 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1147 }
1148 }
1149 Ok(None) => {
1150 debug_assert!(self.current_right_batch.is_none());
1153 self.state = NLJState::FetchingRight;
1154 ControlFlow::Continue(())
1155 }
1156 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1157 }
1158 }
1159
1160 fn handle_emit_left_unmatched(
1162 &mut self,
1163 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1164 if let Some(poll) = self.maybe_flush_ready_batch() {
1166 return ControlFlow::Break(poll);
1167 }
1168
1169 match self.process_left_unmatched() {
1171 Ok(true) => ControlFlow::Continue(()),
1174 Ok(false) => match self.output_buffer.finish_buffered_batch() {
1177 Ok(()) => {
1178 self.state = NLJState::Done;
1179 ControlFlow::Continue(())
1180 }
1181 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1182 },
1183 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1184 }
1185 }
1186
1187 fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
1189 if let Some(poll) = self.maybe_flush_ready_batch() {
1191 return poll;
1192 }
1193
1194 if !self.handled_empty_output {
1200 let zero_count = Count::new();
1201 if *self.join_metrics.baseline.output_rows() == zero_count {
1202 let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
1203 self.handled_empty_output = true;
1204 return Poll::Ready(Some(Ok(empty_batch)));
1205 }
1206 }
1207
1208 Poll::Ready(None)
1209 }
1210
1211 fn process_probe_batch(&mut self) -> Result<bool> {
1218 let left_data = Arc::clone(self.get_left_data()?);
1219 let right_batch = self
1220 .current_right_batch
1221 .as_ref()
1222 .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
1223 .clone();
1224
1225 if self.left_probe_idx >= left_data.batch().num_rows() {
1227 return Ok(false);
1228 }
1229
1230 let l_idx = self.left_probe_idx;
1236 let join_batch =
1237 self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
1238
1239 if let Some(batch) = join_batch {
1240 self.output_buffer.push_batch(batch)?;
1241 }
1242
1243 self.left_probe_idx += 1;
1247
1248 Ok(true)
1250 }
1251
1252 fn process_single_left_row_join(
1255 &mut self,
1256 left_data: &JoinLeftData,
1257 right_batch: &RecordBatch,
1258 l_index: usize,
1259 ) -> Result<Option<RecordBatch>> {
1260 let right_row_count = right_batch.num_rows();
1261 if right_row_count == 0 {
1262 return Ok(None);
1263 }
1264
1265 let cur_right_bitmap = if let Some(filter) = &self.join_filter {
1266 apply_filter_to_row_join_batch(
1267 left_data.batch(),
1268 l_index,
1269 right_batch,
1270 filter,
1271 )?
1272 } else {
1273 BooleanArray::from(vec![true; right_row_count])
1274 };
1275
1276 self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
1277
1278 if matches!(
1281 self.join_type,
1282 JoinType::LeftAnti
1283 | JoinType::LeftSemi
1284 | JoinType::LeftMark
1285 | JoinType::RightAnti
1286 | JoinType::RightMark
1287 | JoinType::RightSemi
1288 ) {
1289 return Ok(None);
1290 }
1291
1292 if cur_right_bitmap.true_count() == 0 {
1293 Ok(None)
1295 } else {
1296 let join_batch = build_row_join_batch(
1298 &self.output_schema,
1299 left_data.batch(),
1300 l_index,
1301 right_batch,
1302 Some(cur_right_bitmap),
1303 &self.column_indices,
1304 JoinSide::Left,
1305 )?;
1306 Ok(join_batch)
1307 }
1308 }
1309
1310 fn process_left_unmatched(&mut self) -> Result<bool> {
1314 let left_data = self.get_left_data()?;
1315 let left_batch = left_data.batch();
1316
1317 let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
1323 let handled_by_other_partition =
1325 self.left_emit_idx == 0 && !left_data.report_probe_completed();
1326 let finished = self.left_emit_idx >= left_batch.num_rows();
1328
1329 if join_type_no_produce_left || handled_by_other_partition || finished {
1330 return Ok(false);
1331 }
1332
1333 let start_idx = self.left_emit_idx;
1338 let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
1339
1340 if let Some(batch) =
1341 self.process_left_unmatched_range(left_data, start_idx, end_idx)?
1342 {
1343 self.output_buffer.push_batch(batch)?;
1344 }
1345
1346 self.left_emit_idx = end_idx;
1348
1349 Ok(true)
1351 }
1352
1353 fn process_left_unmatched_range(
1366 &self,
1367 left_data: &JoinLeftData,
1368 start_idx: usize,
1369 end_idx: usize,
1370 ) -> Result<Option<RecordBatch>> {
1371 if start_idx == end_idx {
1372 return Ok(None);
1373 }
1374
1375 let left_batch = left_data.batch();
1378 let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx);
1379
1380 let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx);
1382 bitmap_sliced.append_n(end_idx - start_idx, false);
1383 let bitmap = left_data.bitmap().lock();
1384 for i in start_idx..end_idx {
1385 assert!(
1386 i - start_idx < bitmap_sliced.capacity(),
1387 "DBG: {start_idx}, {end_idx}"
1388 );
1389 bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
1390 }
1391 let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
1392
1393 build_unmatched_batch(
1394 Arc::clone(&self.output_schema),
1395 &left_batch_sliced,
1396 bitmap_sliced,
1397 self.right_data.schema(),
1398 &self.column_indices,
1399 self.join_type,
1400 JoinSide::Left,
1401 )
1402 }
1403
1404 fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
1407 let right_batch_bitmap: BooleanArray =
1409 std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
1410 internal_datafusion_err!("right bitmap should be available")
1411 })?;
1412
1413 let right_batch = self.current_right_batch.take();
1414 let cur_right_batch = unwrap_or_internal_err!(right_batch);
1415
1416 let left_data = self.get_left_data()?;
1417 let left_schema = left_data.batch().schema();
1418
1419 let res = build_unmatched_batch(
1420 Arc::clone(&self.output_schema),
1421 &cur_right_batch,
1422 right_batch_bitmap,
1423 left_schema,
1424 &self.column_indices,
1425 self.join_type,
1426 JoinSide::Right,
1427 );
1428
1429 self.current_right_batch_matched = None;
1431
1432 res
1433 }
1434
1435 fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
1439 self.buffered_left_data
1440 .as_ref()
1441 .ok_or_else(|| internal_datafusion_err!("LeftData should be available"))
1442 }
1443
1444 fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
1447 if self.output_buffer.has_completed_batch() {
1448 if let Some(batch) = self.output_buffer.next_completed_batch() {
1449 self.join_metrics.output_batches.add(1);
1452
1453 return Some(Poll::Ready(Some(Ok(batch))));
1454 }
1455 }
1456
1457 None
1458 }
1459
1460 fn update_matched_bitmap(
1476 &mut self,
1477 l_index: usize,
1478 r_matched_bitmap: &BooleanArray,
1479 ) -> Result<()> {
1480 let left_data = self.get_left_data()?;
1481
1482 let joined_len = r_matched_bitmap.true_count();
1484
1485 if need_produce_result_in_final(self.join_type) && (joined_len > 0) {
1487 let mut bitmap = left_data.bitmap().lock();
1488 bitmap.set_bit(l_index, true);
1489 }
1490
1491 if self.should_track_unmatched_right {
1493 debug_assert!(self.current_right_batch_matched.is_some());
1494 let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
1496 .ok_or_else(|| {
1497 internal_datafusion_err!("right batch's bitmap should be present")
1498 })?;
1499 let (buf, nulls) = right_bitmap.into_parts();
1500 debug_assert!(nulls.is_none());
1501 let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
1502
1503 self.current_right_batch_matched =
1504 Some(BooleanArray::new(updated_right_bitmap, None));
1505 }
1506
1507 Ok(())
1508 }
1509}
1510
1511fn apply_filter_to_row_join_batch(
1517 left_batch: &RecordBatch,
1518 l_index: usize,
1519 right_batch: &RecordBatch,
1520 filter: &JoinFilter,
1521) -> Result<BooleanArray> {
1522 debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
1523
1524 let intermediate_batch = if filter.schema.fields().is_empty() {
1525 create_record_batch_with_empty_schema(
1528 Arc::new((*filter.schema).clone()),
1529 right_batch.num_rows(),
1530 )?
1531 } else {
1532 build_row_join_batch(
1533 &filter.schema,
1534 left_batch,
1535 l_index,
1536 right_batch,
1537 None,
1538 &filter.column_indices,
1539 JoinSide::Left,
1540 )?
1541 .ok_or_else(|| internal_datafusion_err!("This function assume input batch is not empty, so the intermediate batch can't be empty too"))?
1542 };
1543
1544 let filter_result = filter
1545 .expression()
1546 .evaluate(&intermediate_batch)?
1547 .into_array(intermediate_batch.num_rows())?;
1548 let filter_arr = as_boolean_array(&filter_result)?;
1549
1550 let (is_filtered, nulls) = filter_arr.clone().into_parts();
1555 let bitmap_combined = match nulls {
1556 Some(nulls) => {
1557 let combined = nulls.inner() & &is_filtered;
1558 BooleanArray::new(combined, None)
1559 }
1560 None => BooleanArray::new(is_filtered, None),
1561 };
1562
1563 Ok(bitmap_combined)
1564}
1565
1566fn build_row_join_batch(
1614 output_schema: &Schema,
1615 build_side_batch: &RecordBatch,
1616 build_side_index: usize,
1617 probe_side_batch: &RecordBatch,
1618 probe_side_filter: Option<BooleanArray>,
1619 col_indices: &[ColumnIndex],
1621 build_side: JoinSide,
1624) -> Result<Option<RecordBatch>> {
1625 debug_assert!(build_side != JoinSide::None);
1626
1627 let filtered_probe_batch = if let Some(filter) = probe_side_filter {
1630 &filter_record_batch(probe_side_batch, &filter)?
1631 } else {
1632 probe_side_batch
1633 };
1634
1635 if filtered_probe_batch.num_rows() == 0 {
1636 return Ok(None);
1637 }
1638
1639 if output_schema.fields.is_empty() {
1647 return Ok(Some(create_record_batch_with_empty_schema(
1648 Arc::new(output_schema.clone()),
1649 filtered_probe_batch.num_rows(),
1650 )?));
1651 }
1652
1653 let mut columns: Vec<Arc<dyn Array>> =
1654 Vec::with_capacity(output_schema.fields().len());
1655
1656 for column_index in col_indices {
1657 let array = if column_index.side == build_side {
1658 let original_left_array = build_side_batch.column(column_index.index);
1661 let scalar_value = ScalarValue::try_from_array(
1662 original_left_array.as_ref(),
1663 build_side_index,
1664 )?;
1665 scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
1666 } else {
1667 Arc::clone(filtered_probe_batch.column(column_index.index))
1669 };
1670
1671 columns.push(array);
1672 }
1673
1674 Ok(Some(RecordBatch::try_new(
1675 Arc::new(output_schema.clone()),
1676 columns,
1677 )?))
1678}
1679
1680fn build_unmatched_batch_empty_schema(
1687 output_schema: SchemaRef,
1688 batch_bitmap: &BooleanArray,
1689 join_type: JoinType,
1691) -> Result<Option<RecordBatch>> {
1692 let result_size = match join_type {
1693 JoinType::Left
1694 | JoinType::Right
1695 | JoinType::Full
1696 | JoinType::LeftAnti
1697 | JoinType::RightAnti => batch_bitmap.false_count(),
1698 JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
1699 JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
1700 _ => unreachable!(),
1701 };
1702
1703 if output_schema.fields().is_empty() {
1704 Ok(Some(create_record_batch_with_empty_schema(
1705 Arc::clone(&output_schema),
1706 result_size,
1707 )?))
1708 } else {
1709 Ok(None)
1710 }
1711}
1712
1713fn create_record_batch_with_empty_schema(
1717 schema: SchemaRef,
1718 row_count: usize,
1719) -> Result<RecordBatch> {
1720 let options = RecordBatchOptions::new()
1721 .with_match_field_names(true)
1722 .with_row_count(Some(row_count));
1723
1724 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| {
1725 internal_datafusion_err!("Failed to create empty record batch: {}", e)
1726 })
1727}
1728
1729fn build_unmatched_batch(
1765 output_schema: SchemaRef,
1766 batch: &RecordBatch,
1767 batch_bitmap: BooleanArray,
1768 another_side_schema: SchemaRef,
1770 col_indices: &[ColumnIndex],
1771 join_type: JoinType,
1772 batch_side: JoinSide,
1773) -> Result<Option<RecordBatch>> {
1774 debug_assert_ne!(join_type, JoinType::Inner);
1776 debug_assert_ne!(batch_side, JoinSide::None);
1777
1778 if let Some(batch) = build_unmatched_batch_empty_schema(
1780 Arc::clone(&output_schema),
1781 &batch_bitmap,
1782 join_type,
1783 )? {
1784 return Ok(Some(batch));
1785 }
1786
1787 match join_type {
1788 JoinType::Full | JoinType::Right | JoinType::Left => {
1789 if join_type == JoinType::Right {
1790 debug_assert_eq!(batch_side, JoinSide::Right);
1791 }
1792 if join_type == JoinType::Left {
1793 debug_assert_eq!(batch_side, JoinSide::Left);
1794 }
1795
1796 let flipped_bitmap = not(&batch_bitmap)?;
1799
1800 let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
1802 .fields()
1803 .iter()
1804 .map(|field| new_null_array(field.data_type(), 1))
1805 .collect();
1806
1807 let nullable_left_schema = Arc::new(Schema::new(
1811 another_side_schema
1812 .fields()
1813 .iter()
1814 .map(|field| {
1815 (**field).clone().with_nullable(true)
1816 })
1817 .collect::<Vec<_>>(),
1818 ));
1819 let left_null_batch = if nullable_left_schema.fields.is_empty() {
1820 create_record_batch_with_empty_schema(nullable_left_schema, 0)?
1823 } else {
1824 RecordBatch::try_new(nullable_left_schema, left_null_columns)?
1825 };
1826
1827 debug_assert_ne!(batch_side, JoinSide::None);
1828 let opposite_side = batch_side.negate();
1829
1830 build_row_join_batch(&output_schema, &left_null_batch, 0, batch, Some(flipped_bitmap), col_indices, opposite_side)
1831
1832 },
1833 JoinType::RightSemi | JoinType::RightAnti | JoinType::LeftSemi | JoinType::LeftAnti => {
1834 if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
1835 debug_assert_eq!(batch_side, JoinSide::Right);
1836 }
1837 if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1838 debug_assert_eq!(batch_side, JoinSide::Left);
1839 }
1840
1841 let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) {
1842 batch_bitmap.clone()
1843 } else {
1844 not(&batch_bitmap)?
1845 };
1846
1847 if bitmap.true_count() == 0 {
1848 return Ok(None);
1849 }
1850
1851 let mut columns: Vec<Arc<dyn Array>> =
1852 Vec::with_capacity(output_schema.fields().len());
1853
1854 for column_index in col_indices {
1855 debug_assert!(column_index.side == batch_side);
1856
1857 let col = batch.column(column_index.index);
1858 let filtered_col = filter(col, &bitmap)?;
1859
1860 columns.push(filtered_col);
1861 }
1862
1863 Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1864 },
1865 JoinType::RightMark | JoinType::LeftMark => {
1866 if join_type == JoinType::RightMark {
1867 debug_assert_eq!(batch_side, JoinSide::Right);
1868 }
1869 if join_type == JoinType::LeftMark {
1870 debug_assert_eq!(batch_side, JoinSide::Left);
1871 }
1872
1873 let mut columns: Vec<Arc<dyn Array>> =
1874 Vec::with_capacity(output_schema.fields().len());
1875
1876 let mut right_batch_bitmap_opt = Some(batch_bitmap);
1878
1879 for column_index in col_indices {
1880 if column_index.side == batch_side {
1881 let col = batch.column(column_index.index);
1882
1883 columns.push(Arc::clone(col));
1884 } else if column_index.side == JoinSide::None {
1885 let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
1886 match right_batch_bitmap {
1887 Some(right_batch_bitmap) => {columns.push(Arc::new(right_batch_bitmap))},
1888 None => unreachable!("Should only be one mark column"),
1889 }
1890 } else {
1891 return internal_err!("Not possible to have this join side for RightMark join");
1892 }
1893 }
1894
1895 Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1896 }
1897 _ => internal_err!("If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"),
1898 }
1899}
1900
1901#[cfg(test)]
1902pub(crate) mod tests {
1903 use super::*;
1904 use crate::test::{assert_join_metrics, TestMemoryExec};
1905 use crate::{
1906 common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
1907 };
1908
1909 use arrow::compute::SortOptions;
1910 use arrow::datatypes::{DataType, Field};
1911 use datafusion_common::test_util::batches_to_sort_string;
1912 use datafusion_common::{assert_contains, ScalarValue};
1913 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1914 use datafusion_expr::Operator;
1915 use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
1916 use datafusion_physical_expr::{Partitioning, PhysicalExpr};
1917 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
1918
1919 use insta::allow_duplicates;
1920 use insta::assert_snapshot;
1921 use rstest::rstest;
1922
1923 fn build_table(
1924 a: (&str, &Vec<i32>),
1925 b: (&str, &Vec<i32>),
1926 c: (&str, &Vec<i32>),
1927 batch_size: Option<usize>,
1928 sorted_column_names: Vec<&str>,
1929 ) -> Arc<dyn ExecutionPlan> {
1930 let batch = build_table_i32(a, b, c);
1931 let schema = batch.schema();
1932
1933 let batches = if let Some(batch_size) = batch_size {
1934 let num_batches = batch.num_rows().div_ceil(batch_size);
1935 (0..num_batches)
1936 .map(|i| {
1937 let start = i * batch_size;
1938 let remaining_rows = batch.num_rows() - start;
1939 batch.slice(start, batch_size.min(remaining_rows))
1940 })
1941 .collect::<Vec<_>>()
1942 } else {
1943 vec![batch]
1944 };
1945
1946 let mut sort_info = vec![];
1947 for name in sorted_column_names {
1948 let index = schema.index_of(name).unwrap();
1949 let sort_expr = PhysicalSortExpr::new(
1950 Arc::new(Column::new(name, index)),
1951 SortOptions::new(false, false),
1952 );
1953 sort_info.push(sort_expr);
1954 }
1955 let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
1956 if let Some(ordering) = LexOrdering::new(sort_info) {
1957 source = source.try_with_sort_information(vec![ordering]).unwrap();
1958 }
1959
1960 Arc::new(TestMemoryExec::update_cache(Arc::new(source)))
1961 }
1962
1963 fn build_left_table() -> Arc<dyn ExecutionPlan> {
1964 build_table(
1965 ("a1", &vec![5, 9, 11]),
1966 ("b1", &vec![5, 8, 8]),
1967 ("c1", &vec![50, 90, 110]),
1968 None,
1969 Vec::new(),
1970 )
1971 }
1972
1973 fn build_right_table() -> Arc<dyn ExecutionPlan> {
1974 build_table(
1975 ("a2", &vec![12, 2, 10]),
1976 ("b2", &vec![10, 2, 10]),
1977 ("c2", &vec![40, 80, 100]),
1978 None,
1979 Vec::new(),
1980 )
1981 }
1982
1983 fn prepare_join_filter() -> JoinFilter {
1984 let column_indices = vec![
1985 ColumnIndex {
1986 index: 1,
1987 side: JoinSide::Left,
1988 },
1989 ColumnIndex {
1990 index: 1,
1991 side: JoinSide::Right,
1992 },
1993 ];
1994 let intermediate_schema = Schema::new(vec![
1995 Field::new("x", DataType::Int32, true),
1996 Field::new("x", DataType::Int32, true),
1997 ]);
1998 let left_filter = Arc::new(BinaryExpr::new(
2000 Arc::new(Column::new("x", 0)),
2001 Operator::NotEq,
2002 Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
2003 )) as Arc<dyn PhysicalExpr>;
2004 let right_filter = Arc::new(BinaryExpr::new(
2006 Arc::new(Column::new("x", 1)),
2007 Operator::NotEq,
2008 Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2009 )) as Arc<dyn PhysicalExpr>;
2010 let filter_expression =
2021 Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
2022 as Arc<dyn PhysicalExpr>;
2023
2024 JoinFilter::new(
2025 filter_expression,
2026 column_indices,
2027 Arc::new(intermediate_schema),
2028 )
2029 }
2030
2031 pub(crate) async fn multi_partitioned_join_collect(
2032 left: Arc<dyn ExecutionPlan>,
2033 right: Arc<dyn ExecutionPlan>,
2034 join_type: &JoinType,
2035 join_filter: Option<JoinFilter>,
2036 context: Arc<TaskContext>,
2037 ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
2038 let partition_count = 4;
2039
2040 let right = Arc::new(RepartitionExec::try_new(
2042 right,
2043 Partitioning::RoundRobinBatch(partition_count),
2044 )?) as Arc<dyn ExecutionPlan>;
2045
2046 let nested_loop_join =
2048 NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
2049 let columns = columns(&nested_loop_join.schema());
2050 let mut batches = vec![];
2051 for i in 0..partition_count {
2052 let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
2053 let more_batches = common::collect(stream).await?;
2054 batches.extend(
2055 more_batches
2056 .into_iter()
2057 .inspect(|b| {
2058 assert!(b.num_rows() <= context.session_config().batch_size())
2059 })
2060 .filter(|b| b.num_rows() > 0)
2061 .collect::<Vec<_>>(),
2062 );
2063 }
2064
2065 let metrics = nested_loop_join.metrics().unwrap();
2066
2067 Ok((columns, batches, metrics))
2068 }
2069
2070 fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
2071 let base = TaskContext::default();
2072 let cfg = base.session_config().clone().with_batch_size(batch_size);
2074 Arc::new(base.with_session_config(cfg))
2075 }
2076
2077 #[rstest]
2078 #[tokio::test]
2079 async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2080 let task_ctx = new_task_ctx(batch_size);
2081 dbg!(&batch_size);
2082 let left = build_left_table();
2083 let right = build_right_table();
2084 let filter = prepare_join_filter();
2085 let (columns, batches, metrics) = multi_partitioned_join_collect(
2086 left,
2087 right,
2088 &JoinType::Inner,
2089 Some(filter),
2090 task_ctx,
2091 )
2092 .await?;
2093
2094 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2095 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2096 +----+----+----+----+----+----+
2097 | a1 | b1 | c1 | a2 | b2 | c2 |
2098 +----+----+----+----+----+----+
2099 | 5 | 5 | 50 | 2 | 2 | 80 |
2100 +----+----+----+----+----+----+
2101 "#));
2102
2103 assert_join_metrics!(metrics, 1);
2104
2105 Ok(())
2106 }
2107
2108 #[rstest]
2109 #[tokio::test]
2110 async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2111 let task_ctx = new_task_ctx(batch_size);
2112 let left = build_left_table();
2113 let right = build_right_table();
2114
2115 let filter = prepare_join_filter();
2116 let (columns, batches, metrics) = multi_partitioned_join_collect(
2117 left,
2118 right,
2119 &JoinType::Left,
2120 Some(filter),
2121 task_ctx,
2122 )
2123 .await?;
2124 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2125 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2126 +----+----+-----+----+----+----+
2127 | a1 | b1 | c1 | a2 | b2 | c2 |
2128 +----+----+-----+----+----+----+
2129 | 11 | 8 | 110 | | | |
2130 | 5 | 5 | 50 | 2 | 2 | 80 |
2131 | 9 | 8 | 90 | | | |
2132 +----+----+-----+----+----+----+
2133 "#));
2134
2135 assert_join_metrics!(metrics, 3);
2136
2137 Ok(())
2138 }
2139
2140 #[rstest]
2141 #[tokio::test]
2142 async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2143 let task_ctx = new_task_ctx(batch_size);
2144 let left = build_left_table();
2145 let right = build_right_table();
2146
2147 let filter = prepare_join_filter();
2148 let (columns, batches, metrics) = multi_partitioned_join_collect(
2149 left,
2150 right,
2151 &JoinType::Right,
2152 Some(filter),
2153 task_ctx,
2154 )
2155 .await?;
2156 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2157 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2158 +----+----+----+----+----+-----+
2159 | a1 | b1 | c1 | a2 | b2 | c2 |
2160 +----+----+----+----+----+-----+
2161 | | | | 10 | 10 | 100 |
2162 | | | | 12 | 10 | 40 |
2163 | 5 | 5 | 50 | 2 | 2 | 80 |
2164 +----+----+----+----+----+-----+
2165 "#));
2166
2167 assert_join_metrics!(metrics, 3);
2168
2169 Ok(())
2170 }
2171
2172 #[rstest]
2173 #[tokio::test]
2174 async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
2175 let task_ctx = new_task_ctx(batch_size);
2176 let left = build_left_table();
2177 let right = build_right_table();
2178
2179 let filter = prepare_join_filter();
2180 let (columns, batches, metrics) = multi_partitioned_join_collect(
2181 left,
2182 right,
2183 &JoinType::Full,
2184 Some(filter),
2185 task_ctx,
2186 )
2187 .await?;
2188 assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
2189 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2190 +----+----+-----+----+----+-----+
2191 | a1 | b1 | c1 | a2 | b2 | c2 |
2192 +----+----+-----+----+----+-----+
2193 | | | | 10 | 10 | 100 |
2194 | | | | 12 | 10 | 40 |
2195 | 11 | 8 | 110 | | | |
2196 | 5 | 5 | 50 | 2 | 2 | 80 |
2197 | 9 | 8 | 90 | | | |
2198 +----+----+-----+----+----+-----+
2199 "#));
2200
2201 assert_join_metrics!(metrics, 5);
2202
2203 Ok(())
2204 }
2205
2206 #[rstest]
2207 #[tokio::test]
2208 async fn join_left_semi_with_filter(
2209 #[values(1, 2, 16)] batch_size: usize,
2210 ) -> Result<()> {
2211 let task_ctx = new_task_ctx(batch_size);
2212 let left = build_left_table();
2213 let right = build_right_table();
2214
2215 let filter = prepare_join_filter();
2216 let (columns, batches, metrics) = multi_partitioned_join_collect(
2217 left,
2218 right,
2219 &JoinType::LeftSemi,
2220 Some(filter),
2221 task_ctx,
2222 )
2223 .await?;
2224 assert_eq!(columns, vec!["a1", "b1", "c1"]);
2225 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2226 +----+----+----+
2227 | a1 | b1 | c1 |
2228 +----+----+----+
2229 | 5 | 5 | 50 |
2230 +----+----+----+
2231 "#));
2232
2233 assert_join_metrics!(metrics, 1);
2234
2235 Ok(())
2236 }
2237
2238 #[rstest]
2239 #[tokio::test]
2240 async fn join_left_anti_with_filter(
2241 #[values(1, 2, 16)] batch_size: usize,
2242 ) -> Result<()> {
2243 let task_ctx = new_task_ctx(batch_size);
2244 let left = build_left_table();
2245 let right = build_right_table();
2246
2247 let filter = prepare_join_filter();
2248 let (columns, batches, metrics) = multi_partitioned_join_collect(
2249 left,
2250 right,
2251 &JoinType::LeftAnti,
2252 Some(filter),
2253 task_ctx,
2254 )
2255 .await?;
2256 assert_eq!(columns, vec!["a1", "b1", "c1"]);
2257 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2258 +----+----+-----+
2259 | a1 | b1 | c1 |
2260 +----+----+-----+
2261 | 11 | 8 | 110 |
2262 | 9 | 8 | 90 |
2263 +----+----+-----+
2264 "#));
2265
2266 assert_join_metrics!(metrics, 2);
2267
2268 Ok(())
2269 }
2270
2271 #[rstest]
2272 #[tokio::test]
2273 async fn join_right_semi_with_filter(
2274 #[values(1, 2, 16)] batch_size: usize,
2275 ) -> Result<()> {
2276 let task_ctx = new_task_ctx(batch_size);
2277 let left = build_left_table();
2278 let right = build_right_table();
2279
2280 let filter = prepare_join_filter();
2281 let (columns, batches, metrics) = multi_partitioned_join_collect(
2282 left,
2283 right,
2284 &JoinType::RightSemi,
2285 Some(filter),
2286 task_ctx,
2287 )
2288 .await?;
2289 assert_eq!(columns, vec!["a2", "b2", "c2"]);
2290 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2291 +----+----+----+
2292 | a2 | b2 | c2 |
2293 +----+----+----+
2294 | 2 | 2 | 80 |
2295 +----+----+----+
2296 "#));
2297
2298 assert_join_metrics!(metrics, 1);
2299
2300 Ok(())
2301 }
2302
2303 #[rstest]
2304 #[tokio::test]
2305 async fn join_right_anti_with_filter(
2306 #[values(1, 2, 16)] batch_size: usize,
2307 ) -> Result<()> {
2308 let task_ctx = new_task_ctx(batch_size);
2309 let left = build_left_table();
2310 let right = build_right_table();
2311
2312 let filter = prepare_join_filter();
2313 let (columns, batches, metrics) = multi_partitioned_join_collect(
2314 left,
2315 right,
2316 &JoinType::RightAnti,
2317 Some(filter),
2318 task_ctx,
2319 )
2320 .await?;
2321 assert_eq!(columns, vec!["a2", "b2", "c2"]);
2322 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2323 +----+----+-----+
2324 | a2 | b2 | c2 |
2325 +----+----+-----+
2326 | 10 | 10 | 100 |
2327 | 12 | 10 | 40 |
2328 +----+----+-----+
2329 "#));
2330
2331 assert_join_metrics!(metrics, 2);
2332
2333 Ok(())
2334 }
2335
2336 #[rstest]
2337 #[tokio::test]
2338 async fn join_left_mark_with_filter(
2339 #[values(1, 2, 16)] batch_size: usize,
2340 ) -> Result<()> {
2341 let task_ctx = new_task_ctx(batch_size);
2342 let left = build_left_table();
2343 let right = build_right_table();
2344
2345 let filter = prepare_join_filter();
2346 let (columns, batches, metrics) = multi_partitioned_join_collect(
2347 left,
2348 right,
2349 &JoinType::LeftMark,
2350 Some(filter),
2351 task_ctx,
2352 )
2353 .await?;
2354 assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
2355 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2356 +----+----+-----+-------+
2357 | a1 | b1 | c1 | mark |
2358 +----+----+-----+-------+
2359 | 11 | 8 | 110 | false |
2360 | 5 | 5 | 50 | true |
2361 | 9 | 8 | 90 | false |
2362 +----+----+-----+-------+
2363 "#));
2364
2365 assert_join_metrics!(metrics, 3);
2366
2367 Ok(())
2368 }
2369
2370 #[rstest]
2371 #[tokio::test]
2372 async fn join_right_mark_with_filter(
2373 #[values(1, 2, 16)] batch_size: usize,
2374 ) -> Result<()> {
2375 let task_ctx = new_task_ctx(batch_size);
2376 let left = build_left_table();
2377 let right = build_right_table();
2378
2379 let filter = prepare_join_filter();
2380 let (columns, batches, metrics) = multi_partitioned_join_collect(
2381 left,
2382 right,
2383 &JoinType::RightMark,
2384 Some(filter),
2385 task_ctx,
2386 )
2387 .await?;
2388 assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
2389
2390 allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
2391 +----+----+-----+-------+
2392 | a2 | b2 | c2 | mark |
2393 +----+----+-----+-------+
2394 | 10 | 10 | 100 | false |
2395 | 12 | 10 | 40 | false |
2396 | 2 | 2 | 80 | true |
2397 +----+----+-----+-------+
2398 "#));
2399
2400 assert_join_metrics!(metrics, 3);
2401
2402 Ok(())
2403 }
2404
2405 #[tokio::test]
2406 async fn test_overallocation() -> Result<()> {
2407 let left = build_table(
2408 ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2409 ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2410 ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
2411 None,
2412 Vec::new(),
2413 );
2414 let right = build_table(
2415 ("a2", &vec![10, 11]),
2416 ("b2", &vec![12, 13]),
2417 ("c2", &vec![14, 15]),
2418 None,
2419 Vec::new(),
2420 );
2421 let filter = prepare_join_filter();
2422
2423 let join_types = vec![
2424 JoinType::Inner,
2425 JoinType::Left,
2426 JoinType::Right,
2427 JoinType::Full,
2428 JoinType::LeftSemi,
2429 JoinType::LeftAnti,
2430 JoinType::LeftMark,
2431 JoinType::RightSemi,
2432 JoinType::RightAnti,
2433 JoinType::RightMark,
2434 ];
2435
2436 for join_type in join_types {
2437 let runtime = RuntimeEnvBuilder::new()
2438 .with_memory_limit(100, 1.0)
2439 .build_arc()?;
2440 let task_ctx = TaskContext::default().with_runtime(runtime);
2441 let task_ctx = Arc::new(task_ctx);
2442
2443 let err = multi_partitioned_join_collect(
2444 Arc::clone(&left),
2445 Arc::clone(&right),
2446 &join_type,
2447 Some(filter.clone()),
2448 task_ctx,
2449 )
2450 .await
2451 .unwrap_err();
2452
2453 assert_contains!(
2454 err.to_string(),
2455 "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]"
2456 );
2457 }
2458
2459 Ok(())
2460 }
2461
2462 fn columns(schema: &Schema) -> Vec<String> {
2464 schema.fields().iter().map(|f| f.name().clone()).collect()
2465 }
2466}