1use std::any::Any;
21use std::fmt::Formatter;
22use std::ops::{BitOr, ControlFlow};
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::task::Poll;
26
27use super::utils::{
28 asymmetric_join_output_partitioning, need_produce_result_in_final,
29 reorder_output_after_swap, swap_join_projection,
30};
31use crate::common::can_project;
32use crate::execution_plan::{boundedness_from_children, EmissionType};
33use crate::joins::utils::{
34 build_join_schema, check_join_is_valid, estimate_join_statistics,
35 need_produce_right_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
36 OnceAsync, OnceFut,
37};
38use crate::joins::SharedBitmapBuilder;
39use crate::metrics::{
40 Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, MetricsSet, RatioMetrics,
41};
42use crate::projection::{
43 try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
44 ProjectionExec,
45};
46use crate::{
47 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
48 PlanProperties, RecordBatchStream, SendableRecordBatchStream,
49};
50
51use arrow::array::{
52 new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions,
53 UInt64Array,
54};
55use arrow::buffer::BooleanBuffer;
56use arrow::compute::{
57 concat_batches, filter, filter_record_batch, not, take, BatchCoalescer,
58};
59use arrow::datatypes::{Schema, SchemaRef};
60use arrow::record_batch::RecordBatch;
61use arrow_schema::DataType;
62use datafusion_common::cast::as_boolean_array;
63use datafusion_common::{
64 arrow_err, internal_datafusion_err, internal_err, project_schema,
65 unwrap_or_internal_err, DataFusionError, JoinSide, Result, ScalarValue, Statistics,
66};
67use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
68use datafusion_execution::TaskContext;
69use datafusion_expr::JoinType;
70use datafusion_physical_expr::equivalence::{
71 join_equivalence_properties, ProjectionMapping,
72};
73
74use futures::{Stream, StreamExt, TryStreamExt};
75use log::debug;
76use parking_lot::Mutex;
77
/// Physical plan node that joins two inputs with a nested loop, optionally
/// restricted by a [`JoinFilter`].
///
/// The left child is collected once (see `build_side_data`) and shared by all
/// output partitions, while the right child is executed per output partition.
#[allow(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
pub struct NestedLoopJoinExec {
    /// Left child; `execute` requires it to expose exactly one output partition
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right child; executed once per requested output partition
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Optional join filter
    pub(crate) filter: Option<JoinFilter>,
    /// Kind of join performed
    pub(crate) join_type: JoinType,
    /// Join schema before `projection` is applied (built by `build_join_schema`)
    join_schema: SchemaRef,
    /// Lazily collected left input ([`JoinLeftData`]), initialised at most once
    /// through `try_once` in `execute`
    build_side_data: OnceAsync<JoinLeftData>,
    /// Column index information returned by `build_join_schema`, one entry per
    /// column of `join_schema`
    column_indices: Vec<ColumnIndex>,
    /// Optional projection, expressed as indices into `join_schema`
    projection: Option<Vec<usize>>,

    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once in `try_new`
    cache: PlanProperties,
}
202
203impl NestedLoopJoinExec {
204 pub fn try_new(
206 left: Arc<dyn ExecutionPlan>,
207 right: Arc<dyn ExecutionPlan>,
208 filter: Option<JoinFilter>,
209 join_type: &JoinType,
210 projection: Option<Vec<usize>>,
211 ) -> Result<Self> {
212 let left_schema = left.schema();
213 let right_schema = right.schema();
214 check_join_is_valid(&left_schema, &right_schema, &[])?;
215 let (join_schema, column_indices) =
216 build_join_schema(&left_schema, &right_schema, join_type);
217 let join_schema = Arc::new(join_schema);
218 let cache = Self::compute_properties(
219 &left,
220 &right,
221 Arc::clone(&join_schema),
222 *join_type,
223 projection.as_ref(),
224 )?;
225
226 Ok(NestedLoopJoinExec {
227 left,
228 right,
229 filter,
230 join_type: *join_type,
231 join_schema,
232 build_side_data: Default::default(),
233 column_indices,
234 projection,
235 metrics: Default::default(),
236 cache,
237 })
238 }
239
/// Returns the left child of the join.
pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
    &self.left
}
244
/// Returns the right child of the join.
pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
    &self.right
}
249
/// Returns the join filter, if one was supplied.
pub fn filter(&self) -> Option<&JoinFilter> {
    self.filter.as_ref()
}
254
/// Returns the kind of join this operator performs.
pub fn join_type(&self) -> &JoinType {
    &self.join_type
}
259
/// Returns the embedded projection (indices into the unprojected join
/// schema), if any.
pub fn projection(&self) -> Option<&Vec<usize>> {
    self.projection.as_ref()
}
263
264 fn compute_properties(
266 left: &Arc<dyn ExecutionPlan>,
267 right: &Arc<dyn ExecutionPlan>,
268 schema: SchemaRef,
269 join_type: JoinType,
270 projection: Option<&Vec<usize>>,
271 ) -> Result<PlanProperties> {
272 let mut eq_properties = join_equivalence_properties(
274 left.equivalence_properties().clone(),
275 right.equivalence_properties().clone(),
276 &join_type,
277 Arc::clone(&schema),
278 &Self::maintains_input_order(join_type),
279 None,
280 &[],
282 )?;
283
284 let mut output_partitioning =
285 asymmetric_join_output_partitioning(left, right, &join_type)?;
286
287 let emission_type = if left.boundedness().is_unbounded() {
288 EmissionType::Final
289 } else if right.pipeline_behavior() == EmissionType::Incremental {
290 match join_type {
291 JoinType::Inner
294 | JoinType::LeftSemi
295 | JoinType::RightSemi
296 | JoinType::Right
297 | JoinType::RightAnti
298 | JoinType::RightMark => EmissionType::Incremental,
299 JoinType::Left
302 | JoinType::LeftAnti
303 | JoinType::LeftMark
304 | JoinType::Full => EmissionType::Both,
305 }
306 } else {
307 right.pipeline_behavior()
308 };
309
310 if let Some(projection) = projection {
311 let projection_mapping =
313 ProjectionMapping::from_indices(projection, &schema)?;
314 let out_schema = project_schema(&schema, Some(projection))?;
315 output_partitioning =
316 output_partitioning.project(&projection_mapping, &eq_properties);
317 eq_properties = eq_properties.project(&projection_mapping, out_schema);
318 }
319
320 Ok(PlanProperties::new(
321 eq_properties,
322 output_partitioning,
323 emission_type,
324 boundedness_from_children([left, right]),
325 ))
326 }
327
/// Reports, for `[left, right]`, whether the input order is preserved.
///
/// Always `[false, false]`: the join type is currently ignored.
fn maintains_input_order(_join_type: JoinType) -> Vec<bool> {
    vec![false, false]
}
332
/// Returns `true` when a projection is embedded in this join.
pub fn contains_projection(&self) -> bool {
    self.projection.is_some()
}
336
337 pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
338 can_project(&self.schema(), projection.as_ref())?;
340 let projection = match projection {
341 Some(projection) => match &self.projection {
342 Some(p) => Some(projection.iter().map(|i| p[*i]).collect()),
343 None => Some(projection),
344 },
345 None => None,
346 };
347 Self::try_new(
348 Arc::clone(&self.left),
349 Arc::clone(&self.right),
350 self.filter.clone(),
351 &self.join_type,
352 projection,
353 )
354 }
355
356 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
365 let left = self.left();
366 let right = self.right();
367 let new_join = NestedLoopJoinExec::try_new(
368 Arc::clone(right),
369 Arc::clone(left),
370 self.filter().map(JoinFilter::swap),
371 &self.join_type().swap(),
372 swap_join_projection(
373 left.schema().fields().len(),
374 right.schema().fields().len(),
375 self.projection.as_ref(),
376 self.join_type(),
377 ),
378 )?;
379
380 let plan: Arc<dyn ExecutionPlan> = if matches!(
383 self.join_type(),
384 JoinType::LeftSemi
385 | JoinType::RightSemi
386 | JoinType::LeftAnti
387 | JoinType::RightAnti
388 | JoinType::LeftMark
389 | JoinType::RightMark
390 ) || self.projection.is_some()
391 {
392 Arc::new(new_join)
393 } else {
394 reorder_output_after_swap(
395 Arc::new(new_join),
396 &self.left().schema(),
397 &self.right().schema(),
398 )?
399 };
400
401 Ok(plan)
402 }
403}
404
405impl DisplayAs for NestedLoopJoinExec {
406 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
407 match t {
408 DisplayFormatType::Default | DisplayFormatType::Verbose => {
409 let display_filter = self.filter.as_ref().map_or_else(
410 || "".to_string(),
411 |f| format!(", filter={}", f.expression()),
412 );
413 let display_projections = if self.contains_projection() {
414 format!(
415 ", projection=[{}]",
416 self.projection
417 .as_ref()
418 .unwrap()
419 .iter()
420 .map(|index| format!(
421 "{}@{}",
422 self.join_schema.fields().get(*index).unwrap().name(),
423 index
424 ))
425 .collect::<Vec<_>>()
426 .join(", ")
427 )
428 } else {
429 "".to_string()
430 };
431 write!(
432 f,
433 "NestedLoopJoinExec: join_type={:?}{}{}",
434 self.join_type, display_filter, display_projections
435 )
436 }
437 DisplayFormatType::TreeRender => {
438 if *self.join_type() != JoinType::Inner {
439 writeln!(f, "join_type={:?}", self.join_type)
440 } else {
441 Ok(())
442 }
443 }
444 }
445 }
446}
447
448impl ExecutionPlan for NestedLoopJoinExec {
/// Returns the static operator name.
fn name(&self) -> &'static str {
    "NestedLoopJoinExec"
}
452
/// Returns `self` as [`Any`] so callers can downcast to the concrete type.
fn as_any(&self) -> &dyn Any {
    self
}
456
/// Returns the plan properties computed when the operator was constructed.
fn properties(&self) -> &PlanProperties {
    &self.cache
}
460
/// Required distribution for `[left, right]`: the left child must be a
/// single partition (also checked in `execute`); the right child is
/// unconstrained.
fn required_input_distribution(&self) -> Vec<Distribution> {
    vec![
        Distribution::SinglePartition,
        Distribution::UnspecifiedDistribution,
    ]
}
467
/// Delegates to the associated function of the same name, passing this
/// operator's join type.
fn maintains_input_order(&self) -> Vec<bool> {
    Self::maintains_input_order(self.join_type)
}
471
/// Returns the children in `[left, right]` order.
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
    vec![&self.left, &self.right]
}
475
476 fn with_new_children(
477 self: Arc<Self>,
478 children: Vec<Arc<dyn ExecutionPlan>>,
479 ) -> Result<Arc<dyn ExecutionPlan>> {
480 Ok(Arc::new(NestedLoopJoinExec::try_new(
481 Arc::clone(&children[0]),
482 Arc::clone(&children[1]),
483 self.filter.clone(),
484 &self.join_type,
485 self.projection.clone(),
486 )?))
487 }
488
489 fn execute(
490 &self,
491 partition: usize,
492 context: Arc<TaskContext>,
493 ) -> Result<SendableRecordBatchStream> {
494 if self.left.output_partitioning().partition_count() != 1 {
495 return internal_err!(
496 "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
497 consider using CoalescePartitionsExec or the EnforceDistribution rule"
498 );
499 }
500
501 let metrics = NestedLoopJoinMetrics::new(&self.metrics, partition);
502
503 let load_reservation =
505 MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
506 .register(context.memory_pool());
507
508 let build_side_data = self.build_side_data.try_once(|| {
509 let stream = self.left.execute(0, Arc::clone(&context))?;
510
511 Ok(collect_left_input(
512 stream,
513 metrics.join_metrics.clone(),
514 load_reservation,
515 need_produce_result_in_final(self.join_type),
516 self.right().output_partitioning().partition_count(),
517 ))
518 })?;
519
520 let batch_size = context.session_config().batch_size();
521
522 let probe_side_data = self.right.execute(partition, context)?;
523
524 let column_indices_after_projection = match &self.projection {
526 Some(projection) => projection
527 .iter()
528 .map(|i| self.column_indices[*i].clone())
529 .collect(),
530 None => self.column_indices.clone(),
531 };
532
533 Ok(Box::pin(NestedLoopJoinStream::new(
534 self.schema(),
535 self.filter.clone(),
536 self.join_type,
537 probe_side_data,
538 build_side_data,
539 column_indices_after_projection,
540 metrics,
541 batch_size,
542 )))
543 }
544
/// Returns a copy of the metrics recorded so far (via `clone_inner`).
fn metrics(&self) -> Option<MetricsSet> {
    Some(self.metrics.clone_inner())
}
548
/// Returns statistics for the whole operator; shorthand for
/// `partition_statistics(None)`.
fn statistics(&self) -> Result<Statistics> {
    self.partition_statistics(None)
}
552
553 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
554 if partition.is_some() {
555 return Ok(Statistics::new_unknown(&self.schema()));
556 }
557 estimate_join_statistics(
558 self.left.partition_statistics(None)?,
559 self.right.partition_statistics(None)?,
560 vec![],
561 &self.join_type,
562 &self.schema(),
563 )
564 }
565
566 fn try_swapping_with_projection(
570 &self,
571 projection: &ProjectionExec,
572 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
573 if self.contains_projection() {
575 return Ok(None);
576 }
577
578 if let Some(JoinData {
579 projected_left_child,
580 projected_right_child,
581 join_filter,
582 ..
583 }) = try_pushdown_through_join(
584 projection,
585 self.left(),
586 self.right(),
587 &[],
588 self.schema(),
589 self.filter(),
590 )? {
591 Ok(Some(Arc::new(NestedLoopJoinExec::try_new(
592 Arc::new(projected_left_child),
593 Arc::new(projected_right_child),
594 join_filter,
595 self.join_type(),
596 None,
598 )?)))
599 } else {
600 try_embed_projection(projection, self)
601 }
602 }
603}
604
605impl EmbeddedProjection for NestedLoopJoinExec {
606 fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
607 self.with_projection(projection)
608 }
609}
610
/// Collected left input of the join, shared by all probe partitions.
pub(crate) struct JoinLeftData {
    /// All left-side batches concatenated into one batch
    batch: RecordBatch,
    /// Shared bitmap in which a bit is set once the corresponding left row
    /// has found a match; only sized to the left row count when the join
    /// type needs a final pass over the left side
    bitmap: SharedBitmapBuilder,
    /// Number of probe partitions that have not yet reported completion
    probe_threads_counter: AtomicUsize,
    /// Memory reservation for the collected data. Never read; presumably
    /// kept so the reservation lives as long as this struct — TODO confirm.
    #[expect(dead_code)]
    reservation: MemoryReservation,
}
625
impl JoinLeftData {
    /// Bundles the collected left batch with its visited bitmap, the probe
    /// partition counter and the memory reservation.
    pub(crate) fn new(
        batch: RecordBatch,
        bitmap: SharedBitmapBuilder,
        probe_threads_counter: AtomicUsize,
        reservation: MemoryReservation,
    ) -> Self {
        Self {
            batch,
            bitmap,
            probe_threads_counter,
            reservation,
        }
    }

    /// Returns the concatenated left-side batch.
    pub(crate) fn batch(&self) -> &RecordBatch {
        &self.batch
    }

    /// Returns the shared bitmap of matched left rows.
    pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder {
        &self.bitmap
    }

    /// Decrements the count of outstanding probe partitions and returns
    /// `true` only for the caller that brought it from 1 to 0, i.e. the last
    /// one to finish.
    pub(crate) fn report_probe_completed(&self) -> bool {
        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
    }
}
655
656async fn collect_left_input(
658 stream: SendableRecordBatchStream,
659 join_metrics: BuildProbeJoinMetrics,
660 reservation: MemoryReservation,
661 with_visited_left_side: bool,
662 probe_threads_count: usize,
663) -> Result<JoinLeftData> {
664 let schema = stream.schema();
665
666 let (batches, metrics, mut reservation) = stream
668 .try_fold(
669 (Vec::new(), join_metrics, reservation),
670 |(mut batches, metrics, mut reservation), batch| async {
671 let batch_size = batch.get_array_memory_size();
672 reservation.try_grow(batch_size)?;
674 metrics.build_mem_used.add(batch_size);
676 metrics.build_input_batches.add(1);
677 metrics.build_input_rows.add(batch.num_rows());
678 batches.push(batch);
680 Ok((batches, metrics, reservation))
681 },
682 )
683 .await?;
684
685 let merged_batch = concat_batches(&schema, &batches)?;
686
687 let visited_left_side = if with_visited_left_side {
689 let n_rows = merged_batch.num_rows();
690 let buffer_size = n_rows.div_ceil(8);
691 reservation.try_grow(buffer_size)?;
692 metrics.build_mem_used.add(buffer_size);
693
694 let mut buffer = BooleanBufferBuilder::new(n_rows);
695 buffer.append_n(n_rows, false);
696 buffer
697 } else {
698 BooleanBufferBuilder::new(0)
699 };
700
701 Ok(JoinLeftData::new(
702 merged_batch,
703 Mutex::new(visited_left_side),
704 AtomicUsize::new(probe_threads_count),
705 reservation,
706 ))
707}
708
/// States of the [`NestedLoopJoinStream`] state machine driven by
/// `poll_next`.
#[derive(Debug, Clone, Copy)]
enum NLJState {
    /// Waiting for the shared left-side data; moves to `FetchingRight`
    BufferingLeft,
    /// Polling the right input for the next batch; moves to `ProbeRight` on
    /// a non-empty batch, or to `EmitLeftUnmatched` when the input ends
    FetchingRight,
    /// Joining left rows, one at a time, against the current right batch;
    /// moves to `EmitRightUnmatched` or back to `FetchingRight`
    ProbeRight,
    /// Emitting output derived from the current right batch's matched
    /// bitmap; moves back to `FetchingRight`
    EmitRightUnmatched,
    /// Emitting output derived from the left-side matched bitmap in chunks;
    /// moves to `Done`
    EmitLeftUnmatched,
    /// Flushing remaining buffered output, then ending the stream
    Done,
}
/// Stream that produces the join output for one right-side partition.
pub(crate) struct NestedLoopJoinStream {
    /// Schema of the emitted batches
    pub(crate) output_schema: Arc<Schema>,
    /// Optional join filter evaluated for each (left row, right batch) pair
    pub(crate) join_filter: Option<JoinFilter>,
    /// Kind of join performed
    pub(crate) join_type: JoinType,
    /// Right (probe) input stream
    pub(crate) right_data: SendableRecordBatchStream,
    /// Future resolving to the shared, collected left input
    pub(crate) left_data: OnceFut<JoinLeftData>,
    /// Column index information used to assemble output batches
    pub(crate) column_indices: Vec<ColumnIndex>,
    /// Join metrics for this partition
    pub(crate) metrics: NestedLoopJoinMetrics,

    /// Target batch size: sizes the output coalescer and the chunks used
    /// when scanning the left side in `EmitLeftUnmatched`
    batch_size: usize,

    /// Whether matched rows of each right batch must be tracked
    /// (`need_produce_right_in_final(join_type)`)
    should_track_unmatched_right: bool,

    /// Current state of the state machine
    state: NLJState,
    /// Coalesces produced batches up to `batch_size` rows before emission
    output_buffer: Box<BatchCoalescer>,
    /// Set once the single empty batch for an output with zero rows has
    /// been emitted
    handled_empty_output: bool,

    /// Left-side data, available once `BufferingLeft` has completed
    buffered_left_data: Option<Arc<JoinLeftData>>,
    /// Index of the next left row to join with the current right batch
    left_probe_idx: usize,
    /// Index of the next left row to visit in `EmitLeftUnmatched`
    left_emit_idx: usize,
    /// Set to `true` once the left data has been buffered
    left_exhausted: bool,
    /// Initialised to `true` and not read anywhere in the visible code
    #[allow(dead_code)]
    left_buffered_in_one_pass: bool,

    /// Right batch currently being probed, if any
    current_right_batch: Option<RecordBatch>,
    /// Per-row "matched" flags for `current_right_batch`; only maintained
    /// when `should_track_unmatched_right` is set
    current_right_batch_matched: Option<BooleanArray>,
}
798
/// Metrics recorded by one partition of the nested loop join.
pub(crate) struct NestedLoopJoinMetrics {
    /// Build/probe join metrics
    pub(crate) join_metrics: BuildProbeJoinMetrics,
    /// Ratio metric named "selectivity": its part is the number of emitted
    /// rows and its total is the number of (left row, right row) pairs probed
    pub(crate) selectivity: RatioMetrics,
}
805
806impl NestedLoopJoinMetrics {
807 pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
808 Self {
809 join_metrics: BuildProbeJoinMetrics::new(partition, metrics),
810 selectivity: MetricBuilder::new(metrics)
811 .with_type(MetricType::SUMMARY)
812 .ratio_metrics("selectivity", partition),
813 }
814 }
815}
816
817impl Stream for NestedLoopJoinStream {
818 type Item = Result<RecordBatch>;
819
820 fn poll_next(
851 mut self: std::pin::Pin<&mut Self>,
852 cx: &mut std::task::Context<'_>,
853 ) -> Poll<Option<Self::Item>> {
854 loop {
855 match self.state {
856 NLJState::BufferingLeft => {
862 debug!("[NLJState] Entering: {:?}", self.state);
863 let build_metric = self.metrics.join_metrics.build_time.clone();
868 let _build_timer = build_metric.timer();
869
870 match self.handle_buffering_left(cx) {
871 ControlFlow::Continue(()) => continue,
872 ControlFlow::Break(poll) => return poll,
873 }
874 }
875
876 NLJState::FetchingRight => {
899 debug!("[NLJState] Entering: {:?}", self.state);
900 let join_metric = self.metrics.join_metrics.join_time.clone();
902 let _join_timer = join_metric.timer();
903
904 match self.handle_fetching_right(cx) {
905 ControlFlow::Continue(()) => continue,
906 ControlFlow::Break(poll) => return poll,
907 }
908 }
909
910 NLJState::ProbeRight => {
925 debug!("[NLJState] Entering: {:?}", self.state);
926
927 let join_metric = self.metrics.join_metrics.join_time.clone();
929 let _join_timer = join_metric.timer();
930
931 match self.handle_probe_right() {
932 ControlFlow::Continue(()) => continue,
933 ControlFlow::Break(poll) => {
934 return self.metrics.join_metrics.baseline.record_poll(poll)
935 }
936 }
937 }
938
939 NLJState::EmitRightUnmatched => {
946 debug!("[NLJState] Entering: {:?}", self.state);
947
948 let join_metric = self.metrics.join_metrics.join_time.clone();
950 let _join_timer = join_metric.timer();
951
952 match self.handle_emit_right_unmatched() {
953 ControlFlow::Continue(()) => continue,
954 ControlFlow::Break(poll) => {
955 return self.metrics.join_metrics.baseline.record_poll(poll)
956 }
957 }
958 }
959
960 NLJState::EmitLeftUnmatched => {
976 debug!("[NLJState] Entering: {:?}", self.state);
977
978 let join_metric = self.metrics.join_metrics.join_time.clone();
980 let _join_timer = join_metric.timer();
981
982 match self.handle_emit_left_unmatched() {
983 ControlFlow::Continue(()) => continue,
984 ControlFlow::Break(poll) => {
985 return self.metrics.join_metrics.baseline.record_poll(poll)
986 }
987 }
988 }
989
990 NLJState::Done => {
992 debug!("[NLJState] Entering: {:?}", self.state);
993
994 let join_metric = self.metrics.join_metrics.join_time.clone();
996 let _join_timer = join_metric.timer();
997 let poll = self.handle_done();
1001 return self.metrics.join_metrics.baseline.record_poll(poll);
1002 }
1003 }
1004 }
1005 }
1006}
1007
impl RecordBatchStream for NestedLoopJoinStream {
    /// Returns the schema of the batches produced by this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }
}
1013
1014impl NestedLoopJoinStream {
1015 #[allow(clippy::too_many_arguments)]
1016 pub(crate) fn new(
1017 schema: Arc<Schema>,
1018 filter: Option<JoinFilter>,
1019 join_type: JoinType,
1020 right_data: SendableRecordBatchStream,
1021 left_data: OnceFut<JoinLeftData>,
1022 column_indices: Vec<ColumnIndex>,
1023 metrics: NestedLoopJoinMetrics,
1024 batch_size: usize,
1025 ) -> Self {
1026 Self {
1027 output_schema: Arc::clone(&schema),
1028 join_filter: filter,
1029 join_type,
1030 right_data,
1031 column_indices,
1032 left_data,
1033 metrics,
1034 buffered_left_data: None,
1035 output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)),
1036 batch_size,
1037 current_right_batch: None,
1038 current_right_batch_matched: None,
1039 state: NLJState::BufferingLeft,
1040 left_probe_idx: 0,
1041 left_emit_idx: 0,
1042 left_exhausted: false,
1043 left_buffered_in_one_pass: true,
1044 handled_empty_output: false,
1045 should_track_unmatched_right: need_produce_right_in_final(join_type),
1046 }
1047 }
1048
1049 fn handle_buffering_left(
1053 &mut self,
1054 cx: &mut std::task::Context<'_>,
1055 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1056 match self.left_data.get_shared(cx) {
1057 Poll::Ready(Ok(left_data)) => {
1058 self.buffered_left_data = Some(left_data);
1059 self.left_exhausted = true;
1061 self.state = NLJState::FetchingRight;
1062 ControlFlow::Continue(())
1064 }
1065 Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1066 Poll::Pending => ControlFlow::Break(Poll::Pending),
1067 }
1068 }
1069
1070 fn handle_fetching_right(
1072 &mut self,
1073 cx: &mut std::task::Context<'_>,
1074 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1075 match self.right_data.poll_next_unpin(cx) {
1076 Poll::Ready(result) => match result {
1077 Some(Ok(right_batch)) => {
1078 let right_batch_size = right_batch.num_rows();
1080 self.metrics.join_metrics.input_rows.add(right_batch_size);
1081 self.metrics.join_metrics.input_batches.add(1);
1082
1083 if right_batch_size == 0 {
1085 return ControlFlow::Continue(());
1086 }
1087
1088 self.current_right_batch = Some(right_batch);
1089
1090 if self.should_track_unmatched_right {
1092 let zeroed_buf = BooleanBuffer::new_unset(right_batch_size);
1093 self.current_right_batch_matched =
1094 Some(BooleanArray::new(zeroed_buf, None));
1095 }
1096
1097 self.left_probe_idx = 0;
1098 self.state = NLJState::ProbeRight;
1099 ControlFlow::Continue(())
1100 }
1101 Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1102 None => {
1103 self.state = NLJState::EmitLeftUnmatched;
1105 ControlFlow::Continue(())
1106 }
1107 },
1108 Poll::Pending => ControlFlow::Break(Poll::Pending),
1109 }
1110 }
1111
1112 fn handle_probe_right(&mut self) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1114 if let Some(poll) = self.maybe_flush_ready_batch() {
1116 return ControlFlow::Break(poll);
1117 }
1118
1119 match self.process_probe_batch() {
1121 Ok(true) => ControlFlow::Continue(()),
1125 Ok(false) => {
1129 self.left_probe_idx = 0;
1131
1132 if let (Ok(left_data), Some(right_batch)) =
1135 (self.get_left_data(), self.current_right_batch.as_ref())
1136 {
1137 let left_rows = left_data.batch().num_rows();
1138 let right_rows = right_batch.num_rows();
1139 self.metrics.selectivity.add_total(left_rows * right_rows);
1140 }
1141
1142 if self.should_track_unmatched_right {
1143 debug_assert!(
1144 self.current_right_batch_matched.is_some(),
1145 "If it's required to track matched rows in the right input, the right bitmap must be present"
1146 );
1147 self.state = NLJState::EmitRightUnmatched;
1148 } else {
1149 self.current_right_batch = None;
1150 self.state = NLJState::FetchingRight;
1151 }
1152 ControlFlow::Continue(())
1153 }
1154 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1155 }
1156 }
1157
1158 fn handle_emit_right_unmatched(
1160 &mut self,
1161 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1162 if let Some(poll) = self.maybe_flush_ready_batch() {
1164 return ControlFlow::Break(poll);
1165 }
1166
1167 debug_assert!(
1168 self.current_right_batch_matched.is_some()
1169 && self.current_right_batch.is_some(),
1170 "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present"
1171 );
1172 match self.process_right_unmatched() {
1174 Ok(Some(batch)) => {
1175 match self.output_buffer.push_batch(batch) {
1176 Ok(()) => {
1177 debug_assert!(self.current_right_batch.is_none());
1180 self.state = NLJState::FetchingRight;
1181 ControlFlow::Continue(())
1182 }
1183 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1184 }
1185 }
1186 Ok(None) => {
1187 debug_assert!(self.current_right_batch.is_none());
1190 self.state = NLJState::FetchingRight;
1191 ControlFlow::Continue(())
1192 }
1193 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1194 }
1195 }
1196
1197 fn handle_emit_left_unmatched(
1199 &mut self,
1200 ) -> ControlFlow<Poll<Option<Result<RecordBatch>>>> {
1201 if let Some(poll) = self.maybe_flush_ready_batch() {
1203 return ControlFlow::Break(poll);
1204 }
1205
1206 match self.process_left_unmatched() {
1208 Ok(true) => ControlFlow::Continue(()),
1211 Ok(false) => match self.output_buffer.finish_buffered_batch() {
1214 Ok(()) => {
1215 self.state = NLJState::Done;
1216 ControlFlow::Continue(())
1217 }
1218 Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))),
1219 },
1220 Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))),
1221 }
1222 }
1223
1224 fn handle_done(&mut self) -> Poll<Option<Result<RecordBatch>>> {
1226 if let Some(poll) = self.maybe_flush_ready_batch() {
1228 return poll;
1229 }
1230
1231 if !self.handled_empty_output {
1237 let zero_count = Count::new();
1238 if *self.metrics.join_metrics.baseline.output_rows() == zero_count {
1239 let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema));
1240 self.handled_empty_output = true;
1241 return Poll::Ready(Some(Ok(empty_batch)));
1242 }
1243 }
1244
1245 Poll::Ready(None)
1246 }
1247
1248 fn process_probe_batch(&mut self) -> Result<bool> {
1255 let left_data = Arc::clone(self.get_left_data()?);
1256 let right_batch = self
1257 .current_right_batch
1258 .as_ref()
1259 .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))?
1260 .clone();
1261
1262 if self.left_probe_idx >= left_data.batch().num_rows() {
1264 return Ok(false);
1265 }
1266
1267 let l_idx = self.left_probe_idx;
1273 let join_batch =
1274 self.process_single_left_row_join(&left_data, &right_batch, l_idx)?;
1275
1276 if let Some(batch) = join_batch {
1277 self.output_buffer.push_batch(batch)?;
1278 }
1279
1280 self.left_probe_idx += 1;
1284
1285 Ok(true)
1287 }
1288
1289 fn process_single_left_row_join(
1292 &mut self,
1293 left_data: &JoinLeftData,
1294 right_batch: &RecordBatch,
1295 l_index: usize,
1296 ) -> Result<Option<RecordBatch>> {
1297 let right_row_count = right_batch.num_rows();
1298 if right_row_count == 0 {
1299 return Ok(None);
1300 }
1301
1302 let cur_right_bitmap = if let Some(filter) = &self.join_filter {
1303 apply_filter_to_row_join_batch(
1304 left_data.batch(),
1305 l_index,
1306 right_batch,
1307 filter,
1308 )?
1309 } else {
1310 BooleanArray::from(vec![true; right_row_count])
1311 };
1312
1313 self.update_matched_bitmap(l_index, &cur_right_bitmap)?;
1314
1315 if matches!(
1318 self.join_type,
1319 JoinType::LeftAnti
1320 | JoinType::LeftSemi
1321 | JoinType::LeftMark
1322 | JoinType::RightAnti
1323 | JoinType::RightMark
1324 | JoinType::RightSemi
1325 ) {
1326 return Ok(None);
1327 }
1328
1329 if cur_right_bitmap.true_count() == 0 {
1330 Ok(None)
1332 } else {
1333 let join_batch = build_row_join_batch(
1335 &self.output_schema,
1336 left_data.batch(),
1337 l_index,
1338 right_batch,
1339 Some(cur_right_bitmap),
1340 &self.column_indices,
1341 JoinSide::Left,
1342 )?;
1343 Ok(join_batch)
1344 }
1345 }
1346
1347 fn process_left_unmatched(&mut self) -> Result<bool> {
1351 let left_data = self.get_left_data()?;
1352 let left_batch = left_data.batch();
1353
1354 let join_type_no_produce_left = !need_produce_result_in_final(self.join_type);
1360 let handled_by_other_partition =
1362 self.left_emit_idx == 0 && !left_data.report_probe_completed();
1363 let finished = self.left_emit_idx >= left_batch.num_rows();
1365
1366 if join_type_no_produce_left || handled_by_other_partition || finished {
1367 return Ok(false);
1368 }
1369
1370 let start_idx = self.left_emit_idx;
1375 let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows());
1376
1377 if let Some(batch) =
1378 self.process_left_unmatched_range(left_data, start_idx, end_idx)?
1379 {
1380 self.output_buffer.push_batch(batch)?;
1381 }
1382
1383 self.left_emit_idx = end_idx;
1385
1386 Ok(true)
1388 }
1389
1390 fn process_left_unmatched_range(
1403 &self,
1404 left_data: &JoinLeftData,
1405 start_idx: usize,
1406 end_idx: usize,
1407 ) -> Result<Option<RecordBatch>> {
1408 if start_idx == end_idx {
1409 return Ok(None);
1410 }
1411
1412 let left_batch = left_data.batch();
1415 let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx);
1416
1417 let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx);
1419 bitmap_sliced.append_n(end_idx - start_idx, false);
1420 let bitmap = left_data.bitmap().lock();
1421 for i in start_idx..end_idx {
1422 assert!(
1423 i - start_idx < bitmap_sliced.capacity(),
1424 "DBG: {start_idx}, {end_idx}"
1425 );
1426 bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i));
1427 }
1428 let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None);
1429
1430 build_unmatched_batch(
1431 Arc::clone(&self.output_schema),
1432 &left_batch_sliced,
1433 bitmap_sliced,
1434 self.right_data.schema(),
1435 &self.column_indices,
1436 self.join_type,
1437 JoinSide::Left,
1438 )
1439 }
1440
1441 fn process_right_unmatched(&mut self) -> Result<Option<RecordBatch>> {
1444 let right_batch_bitmap: BooleanArray =
1446 std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| {
1447 internal_datafusion_err!("right bitmap should be available")
1448 })?;
1449
1450 let right_batch = self.current_right_batch.take();
1451 let cur_right_batch = unwrap_or_internal_err!(right_batch);
1452
1453 let left_data = self.get_left_data()?;
1454 let left_schema = left_data.batch().schema();
1455
1456 let res = build_unmatched_batch(
1457 Arc::clone(&self.output_schema),
1458 &cur_right_batch,
1459 right_batch_bitmap,
1460 left_schema,
1461 &self.column_indices,
1462 self.join_type,
1463 JoinSide::Right,
1464 );
1465
1466 self.current_right_batch_matched = None;
1468
1469 res
1470 }
1471
1472 fn get_left_data(&self) -> Result<&Arc<JoinLeftData>> {
1476 self.buffered_left_data
1477 .as_ref()
1478 .ok_or_else(|| internal_datafusion_err!("LeftData should be available"))
1479 }
1480
1481 fn maybe_flush_ready_batch(&mut self) -> Option<Poll<Option<Result<RecordBatch>>>> {
1484 if self.output_buffer.has_completed_batch() {
1485 if let Some(batch) = self.output_buffer.next_completed_batch() {
1486 self.metrics.join_metrics.output_batches.add(1);
1489
1490 let output_rows = batch.num_rows();
1492 self.metrics.selectivity.add_part(output_rows);
1493
1494 return Some(Poll::Ready(Some(Ok(batch))));
1495 }
1496 }
1497
1498 None
1499 }
1500
1501 fn update_matched_bitmap(
1517 &mut self,
1518 l_index: usize,
1519 r_matched_bitmap: &BooleanArray,
1520 ) -> Result<()> {
1521 let left_data = self.get_left_data()?;
1522
1523 let joined_len = r_matched_bitmap.true_count();
1525
1526 if need_produce_result_in_final(self.join_type) && (joined_len > 0) {
1528 let mut bitmap = left_data.bitmap().lock();
1529 bitmap.set_bit(l_index, true);
1530 }
1531
1532 if self.should_track_unmatched_right {
1534 debug_assert!(self.current_right_batch_matched.is_some());
1535 let right_bitmap = std::mem::take(&mut self.current_right_batch_matched)
1537 .ok_or_else(|| {
1538 internal_datafusion_err!("right batch's bitmap should be present")
1539 })?;
1540 let (buf, nulls) = right_bitmap.into_parts();
1541 debug_assert!(nulls.is_none());
1542 let updated_right_bitmap = buf.bitor(r_matched_bitmap.values());
1543
1544 self.current_right_batch_matched =
1545 Some(BooleanArray::new(updated_right_bitmap, None));
1546 }
1547
1548 Ok(())
1549 }
1550}
1551
1552fn apply_filter_to_row_join_batch(
1558 left_batch: &RecordBatch,
1559 l_index: usize,
1560 right_batch: &RecordBatch,
1561 filter: &JoinFilter,
1562) -> Result<BooleanArray> {
1563 debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0);
1564
1565 let intermediate_batch = if filter.schema.fields().is_empty() {
1566 create_record_batch_with_empty_schema(
1569 Arc::new((*filter.schema).clone()),
1570 right_batch.num_rows(),
1571 )?
1572 } else {
1573 build_row_join_batch(
1574 &filter.schema,
1575 left_batch,
1576 l_index,
1577 right_batch,
1578 None,
1579 &filter.column_indices,
1580 JoinSide::Left,
1581 )?
1582 .ok_or_else(|| internal_datafusion_err!("This function assume input batch is not empty, so the intermediate batch can't be empty too"))?
1583 };
1584
1585 let filter_result = filter
1586 .expression()
1587 .evaluate(&intermediate_batch)?
1588 .into_array(intermediate_batch.num_rows())?;
1589 let filter_arr = as_boolean_array(&filter_result)?;
1590
1591 let (is_filtered, nulls) = filter_arr.clone().into_parts();
1596 let bitmap_combined = match nulls {
1597 Some(nulls) => {
1598 let combined = nulls.inner() & &is_filtered;
1599 BooleanArray::new(combined, None)
1600 }
1601 None => BooleanArray::new(is_filtered, None),
1602 };
1603
1604 Ok(bitmap_combined)
1605}
1606
1607fn build_row_join_batch(
1655 output_schema: &Schema,
1656 build_side_batch: &RecordBatch,
1657 build_side_index: usize,
1658 probe_side_batch: &RecordBatch,
1659 probe_side_filter: Option<BooleanArray>,
1660 col_indices: &[ColumnIndex],
1662 build_side: JoinSide,
1665) -> Result<Option<RecordBatch>> {
1666 debug_assert!(build_side != JoinSide::None);
1667
1668 let filtered_probe_batch = if let Some(filter) = probe_side_filter {
1671 &filter_record_batch(probe_side_batch, &filter)?
1672 } else {
1673 probe_side_batch
1674 };
1675
1676 if filtered_probe_batch.num_rows() == 0 {
1677 return Ok(None);
1678 }
1679
1680 if output_schema.fields.is_empty() {
1688 return Ok(Some(create_record_batch_with_empty_schema(
1689 Arc::new(output_schema.clone()),
1690 filtered_probe_batch.num_rows(),
1691 )?));
1692 }
1693
1694 let mut columns: Vec<Arc<dyn Array>> =
1695 Vec::with_capacity(output_schema.fields().len());
1696
1697 for column_index in col_indices {
1698 let array = if column_index.side == build_side {
1699 let original_left_array = build_side_batch.column(column_index.index);
1702 match original_left_array.data_type() {
1708 DataType::List(field) | DataType::LargeList(field)
1709 if field.data_type() == &DataType::Utf8View =>
1710 {
1711 let indices_iter = std::iter::repeat_n(
1712 build_side_index as u64,
1713 filtered_probe_batch.num_rows(),
1714 );
1715 let indices_array = UInt64Array::from_iter_values(indices_iter);
1716 take(original_left_array.as_ref(), &indices_array, None)?
1717 }
1718 _ => {
1719 let scalar_value = ScalarValue::try_from_array(
1720 original_left_array.as_ref(),
1721 build_side_index,
1722 )?;
1723 scalar_value.to_array_of_size(filtered_probe_batch.num_rows())?
1724 }
1725 }
1726 } else {
1727 Arc::clone(filtered_probe_batch.column(column_index.index))
1729 };
1730
1731 columns.push(array);
1732 }
1733
1734 Ok(Some(RecordBatch::try_new(
1735 Arc::new(output_schema.clone()),
1736 columns,
1737 )?))
1738}
1739
1740fn build_unmatched_batch_empty_schema(
1747 output_schema: SchemaRef,
1748 batch_bitmap: &BooleanArray,
1749 join_type: JoinType,
1751) -> Result<Option<RecordBatch>> {
1752 let result_size = match join_type {
1753 JoinType::Left
1754 | JoinType::Right
1755 | JoinType::Full
1756 | JoinType::LeftAnti
1757 | JoinType::RightAnti => batch_bitmap.false_count(),
1758 JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(),
1759 JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(),
1760 _ => unreachable!(),
1761 };
1762
1763 if output_schema.fields().is_empty() {
1764 Ok(Some(create_record_batch_with_empty_schema(
1765 Arc::clone(&output_schema),
1766 result_size,
1767 )?))
1768 } else {
1769 Ok(None)
1770 }
1771}
1772
1773fn create_record_batch_with_empty_schema(
1777 schema: SchemaRef,
1778 row_count: usize,
1779) -> Result<RecordBatch> {
1780 let options = RecordBatchOptions::new()
1781 .with_match_field_names(true)
1782 .with_row_count(Some(row_count));
1783
1784 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| {
1785 internal_datafusion_err!("Failed to create empty record batch: {}", e)
1786 })
1787}
1788
1789fn build_unmatched_batch(
1825 output_schema: SchemaRef,
1826 batch: &RecordBatch,
1827 batch_bitmap: BooleanArray,
1828 another_side_schema: SchemaRef,
1830 col_indices: &[ColumnIndex],
1831 join_type: JoinType,
1832 batch_side: JoinSide,
1833) -> Result<Option<RecordBatch>> {
1834 debug_assert_ne!(join_type, JoinType::Inner);
1836 debug_assert_ne!(batch_side, JoinSide::None);
1837
1838 if let Some(batch) = build_unmatched_batch_empty_schema(
1840 Arc::clone(&output_schema),
1841 &batch_bitmap,
1842 join_type,
1843 )? {
1844 return Ok(Some(batch));
1845 }
1846
1847 match join_type {
1848 JoinType::Full | JoinType::Right | JoinType::Left => {
1849 if join_type == JoinType::Right {
1850 debug_assert_eq!(batch_side, JoinSide::Right);
1851 }
1852 if join_type == JoinType::Left {
1853 debug_assert_eq!(batch_side, JoinSide::Left);
1854 }
1855
1856 let flipped_bitmap = not(&batch_bitmap)?;
1859
1860 let left_null_columns: Vec<Arc<dyn Array>> = another_side_schema
1862 .fields()
1863 .iter()
1864 .map(|field| new_null_array(field.data_type(), 1))
1865 .collect();
1866
1867 let nullable_left_schema = Arc::new(Schema::new(
1871 another_side_schema
1872 .fields()
1873 .iter()
1874 .map(|field| {
1875 (**field).clone().with_nullable(true)
1876 })
1877 .collect::<Vec<_>>(),
1878 ));
1879 let left_null_batch = if nullable_left_schema.fields.is_empty() {
1880 create_record_batch_with_empty_schema(nullable_left_schema, 0)?
1883 } else {
1884 RecordBatch::try_new(nullable_left_schema, left_null_columns)?
1885 };
1886
1887 debug_assert_ne!(batch_side, JoinSide::None);
1888 let opposite_side = batch_side.negate();
1889
1890 build_row_join_batch(&output_schema, &left_null_batch, 0, batch, Some(flipped_bitmap), col_indices, opposite_side)
1891
1892 },
1893 JoinType::RightSemi | JoinType::RightAnti | JoinType::LeftSemi | JoinType::LeftAnti => {
1894 if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) {
1895 debug_assert_eq!(batch_side, JoinSide::Right);
1896 }
1897 if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1898 debug_assert_eq!(batch_side, JoinSide::Left);
1899 }
1900
1901 let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) {
1902 batch_bitmap.clone()
1903 } else {
1904 not(&batch_bitmap)?
1905 };
1906
1907 if bitmap.true_count() == 0 {
1908 return Ok(None);
1909 }
1910
1911 let mut columns: Vec<Arc<dyn Array>> =
1912 Vec::with_capacity(output_schema.fields().len());
1913
1914 for column_index in col_indices {
1915 debug_assert!(column_index.side == batch_side);
1916
1917 let col = batch.column(column_index.index);
1918 let filtered_col = filter(col, &bitmap)?;
1919
1920 columns.push(filtered_col);
1921 }
1922
1923 Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1924 },
1925 JoinType::RightMark | JoinType::LeftMark => {
1926 if join_type == JoinType::RightMark {
1927 debug_assert_eq!(batch_side, JoinSide::Right);
1928 }
1929 if join_type == JoinType::LeftMark {
1930 debug_assert_eq!(batch_side, JoinSide::Left);
1931 }
1932
1933 let mut columns: Vec<Arc<dyn Array>> =
1934 Vec::with_capacity(output_schema.fields().len());
1935
1936 let mut right_batch_bitmap_opt = Some(batch_bitmap);
1938
1939 for column_index in col_indices {
1940 if column_index.side == batch_side {
1941 let col = batch.column(column_index.index);
1942
1943 columns.push(Arc::clone(col));
1944 } else if column_index.side == JoinSide::None {
1945 let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt);
1946 match right_batch_bitmap {
1947 Some(right_batch_bitmap) => {columns.push(Arc::new(right_batch_bitmap))},
1948 None => unreachable!("Should only be one mark column"),
1949 }
1950 } else {
1951 return internal_err!("Not possible to have this join side for RightMark join");
1952 }
1953 }
1954
1955 Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?))
1956 }
1957 _ => internal_err!("If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"),
1958 }
1959}
1960
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::test::{assert_join_metrics, TestMemoryExec};
    use crate::{
        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
    };

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::test_util::batches_to_sort_string;
    use datafusion_common::{assert_contains, ScalarValue};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
    use datafusion_physical_expr::{Partitioning, PhysicalExpr};
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    use insta::allow_duplicates;
    use insta::assert_snapshot;
    use rstest::rstest;

    /// Builds an in-memory, single-partition plan with three `i32` columns.
    ///
    /// * `a`, `b`, `c` - `(column name, values)` pairs.
    /// * `batch_size` - when `Some`, the data is split into batches of at most
    ///   this many rows; when `None`, a single batch is used.
    /// * `sorted_column_names` - columns to register as sort information on
    ///   the source (no ordering is registered when empty).
    fn build_table(
        a: (&str, &Vec<i32>),
        b: (&str, &Vec<i32>),
        c: (&str, &Vec<i32>),
        batch_size: Option<usize>,
        sorted_column_names: Vec<&str>,
    ) -> Arc<dyn ExecutionPlan> {
        let batch = build_table_i32(a, b, c);
        let schema = batch.schema();

        let batches = if let Some(batch_size) = batch_size {
            let num_batches = batch.num_rows().div_ceil(batch_size);
            (0..num_batches)
                .map(|i| {
                    let start = i * batch_size;
                    let remaining_rows = batch.num_rows() - start;
                    // The last slice may be shorter than `batch_size`.
                    batch.slice(start, batch_size.min(remaining_rows))
                })
                .collect::<Vec<_>>()
        } else {
            vec![batch]
        };

        let mut sort_info = vec![];
        for name in sorted_column_names {
            let index = schema.index_of(name).unwrap();
            let sort_expr = PhysicalSortExpr::new(
                Arc::new(Column::new(name, index)),
                SortOptions::new(false, false),
            );
            sort_info.push(sort_expr);
        }
        let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap();
        if let Some(ordering) = LexOrdering::new(sort_info) {
            source = source.try_with_sort_information(vec![ordering]).unwrap();
        }

        Arc::new(TestMemoryExec::update_cache(Arc::new(source)))
    }

    /// Left input used by the join tests: columns `a1`, `b1`, `c1`, three rows.
    fn build_left_table() -> Arc<dyn ExecutionPlan> {
        build_table(
            ("a1", &vec![5, 9, 11]),
            ("b1", &vec![5, 8, 8]),
            ("c1", &vec![50, 90, 110]),
            None,
            Vec::new(),
        )
    }

    /// Right input used by the join tests: columns `a2`, `b2`, `c2`, three rows.
    fn build_right_table() -> Arc<dyn ExecutionPlan> {
        build_table(
            ("a2", &vec![12, 2, 10]),
            ("b2", &vec![10, 2, 10]),
            ("c2", &vec![40, 80, 100]),
            None,
            Vec::new(),
        )
    }

    /// Builds the join filter shared by the tests.
    ///
    /// The intermediate schema has two columns: column 0 is left column 1 and
    /// column 1 is right column 1. The filter is `col0 != 8 AND col1 != 10`.
    /// With the tables from [`build_left_table`] and [`build_right_table`]
    /// that is `b1 != 8 AND b2 != 10`, which only the pair
    /// `(5, 5, 50)` / `(2, 2, 80)` satisfies.
    fn prepare_join_filter() -> JoinFilter {
        let column_indices = vec![
            ColumnIndex {
                index: 1,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Right,
            },
        ];
        let intermediate_schema = Schema::new(vec![
            Field::new("x", DataType::Int32, true),
            Field::new("x", DataType::Int32, true),
        ]);
        // Left part: intermediate column 0 != 8
        let left_filter = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("x", 0)),
            Operator::NotEq,
            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
        )) as Arc<dyn PhysicalExpr>;
        // Right part: intermediate column 1 != 10
        let right_filter = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("x", 1)),
            Operator::NotEq,
            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
        )) as Arc<dyn PhysicalExpr>;
        let filter_expression =
            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
                as Arc<dyn PhysicalExpr>;

        JoinFilter::new(
            filter_expression,
            column_indices,
            Arc::new(intermediate_schema),
        )
    }

    /// Runs a nested loop join with the right input repartitioned round-robin
    /// into 4 partitions and collects the output of every partition.
    ///
    /// Every produced batch is asserted to hold at most the configured batch
    /// size of rows; empty batches are dropped from the result.
    ///
    /// Returns the output column names, the non-empty output batches and the
    /// metrics of the join node.
    pub(crate) async fn multi_partitioned_join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        join_type: &JoinType,
        join_filter: Option<JoinFilter>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
        let partition_count = 4;

        let right = Arc::new(RepartitionExec::try_new(
            right,
            Partitioning::RoundRobinBatch(partition_count),
        )?) as Arc<dyn ExecutionPlan>;

        let nested_loop_join =
            NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?;
        let columns = columns(&nested_loop_join.schema());
        let mut batches = vec![];
        for i in 0..partition_count {
            let stream = nested_loop_join.execute(i, Arc::clone(&context))?;
            let more_batches = common::collect(stream).await?;
            batches.extend(
                more_batches
                    .into_iter()
                    .inspect(|b| {
                        assert!(b.num_rows() <= context.session_config().batch_size())
                    })
                    .filter(|b| b.num_rows() > 0),
            );
        }

        let metrics = nested_loop_join.metrics().unwrap();

        Ok((columns, batches, metrics))
    }

    /// Creates a default task context whose session config uses `batch_size`.
    fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
        let base = TaskContext::default();
        let cfg = base.session_config().clone().with_batch_size(batch_size);
        Arc::new(base.with_session_config(cfg))
    }

    /// Inner join: only the single pair passing the filter is produced.
    #[rstest]
    #[tokio::test]
    async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();
        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::Inner,
            Some(filter),
            task_ctx,
        )
        .await?;

        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+----+----+----+----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+----+----+----+----+
        | 5 | 5 | 50 | 2 | 2 | 80 |
        +----+----+----+----+----+----+
        "#));

        assert_join_metrics!(metrics, 1);

        Ok(())
    }

    /// Left join: the matching pair plus the two unmatched left rows padded
    /// with nulls.
    #[rstest]
    #[tokio::test]
    async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::Left,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+-----+----+----+----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+-----+----+----+----+
        | 11 | 8 | 110 | | | |
        | 5 | 5 | 50 | 2 | 2 | 80 |
        | 9 | 8 | 90 | | | |
        +----+----+-----+----+----+----+
        "#));

        assert_join_metrics!(metrics, 3);

        Ok(())
    }

    /// Right join: the matching pair plus the two unmatched right rows padded
    /// with nulls.
    #[rstest]
    #[tokio::test]
    async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::Right,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+----+----+----+-----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+----+----+----+-----+
        | | | | 10 | 10 | 100 |
        | | | | 12 | 10 | 40 |
        | 5 | 5 | 50 | 2 | 2 | 80 |
        +----+----+----+----+----+-----+
        "#));

        assert_join_metrics!(metrics, 3);

        Ok(())
    }

    /// Full join: the matching pair plus the unmatched rows of both sides.
    #[rstest]
    #[tokio::test]
    async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::Full,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+-----+----+----+-----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+-----+----+----+-----+
        | | | | 10 | 10 | 100 |
        | | | | 12 | 10 | 40 |
        | 11 | 8 | 110 | | | |
        | 5 | 5 | 50 | 2 | 2 | 80 |
        | 9 | 8 | 90 | | | |
        +----+----+-----+----+----+-----+
        "#));

        assert_join_metrics!(metrics, 5);

        Ok(())
    }

    /// Left semi join: only the left row that has a match is produced.
    #[rstest]
    #[tokio::test]
    async fn join_left_semi_with_filter(
        #[values(1, 2, 16)] batch_size: usize,
    ) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::LeftSemi,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a1", "b1", "c1"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+----+
        | a1 | b1 | c1 |
        +----+----+----+
        | 5 | 5 | 50 |
        +----+----+----+
        "#));

        assert_join_metrics!(metrics, 1);

        Ok(())
    }

    /// Left anti join: only the left rows without a match are produced.
    #[rstest]
    #[tokio::test]
    async fn join_left_anti_with_filter(
        #[values(1, 2, 16)] batch_size: usize,
    ) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::LeftAnti,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a1", "b1", "c1"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+-----+
        | a1 | b1 | c1 |
        +----+----+-----+
        | 11 | 8 | 110 |
        | 9 | 8 | 90 |
        +----+----+-----+
        "#));

        assert_join_metrics!(metrics, 2);

        Ok(())
    }

    /// The number of column statistics must follow the projected output
    /// schema (two columns for the projection `[1, 2]`).
    #[tokio::test]
    async fn join_has_correct_stats() -> Result<()> {
        let left = build_left_table();
        let right = build_right_table();
        let nested_loop_join = NestedLoopJoinExec::try_new(
            left,
            right,
            None,
            &JoinType::Left,
            Some(vec![1, 2]),
        )?;
        let stats = nested_loop_join.partition_statistics(None)?;
        assert_eq!(
            nested_loop_join.schema().fields().len(),
            stats.column_statistics.len(),
        );
        assert_eq!(2, stats.column_statistics.len());
        Ok(())
    }

    /// Right semi join: only the right row that has a match is produced.
    #[rstest]
    #[tokio::test]
    async fn join_right_semi_with_filter(
        #[values(1, 2, 16)] batch_size: usize,
    ) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::RightSemi,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a2", "b2", "c2"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+----+
        | a2 | b2 | c2 |
        +----+----+----+
        | 2 | 2 | 80 |
        +----+----+----+
        "#));

        assert_join_metrics!(metrics, 1);

        Ok(())
    }

    /// Right anti join: only the right rows without a match are produced.
    #[rstest]
    #[tokio::test]
    async fn join_right_anti_with_filter(
        #[values(1, 2, 16)] batch_size: usize,
    ) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::RightAnti,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a2", "b2", "c2"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+-----+
        | a2 | b2 | c2 |
        +----+----+-----+
        | 10 | 10 | 100 |
        | 12 | 10 | 40 |
        +----+----+-----+
        "#));

        assert_join_metrics!(metrics, 2);

        Ok(())
    }

    /// Left mark join: every left row is produced with a `mark` column that
    /// tells whether the row had a match.
    #[rstest]
    #[tokio::test]
    async fn join_left_mark_with_filter(
        #[values(1, 2, 16)] batch_size: usize,
    ) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::LeftMark,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+-----+-------+
        | a1 | b1 | c1 | mark |
        +----+----+-----+-------+
        | 11 | 8 | 110 | false |
        | 5 | 5 | 50 | true |
        | 9 | 8 | 90 | false |
        +----+----+-----+-------+
        "#));

        assert_join_metrics!(metrics, 3);

        Ok(())
    }

    /// Right mark join: every right row is produced with a `mark` column that
    /// tells whether the row had a match.
    #[rstest]
    #[tokio::test]
    async fn join_right_mark_with_filter(
        #[values(1, 2, 16)] batch_size: usize,
    ) -> Result<()> {
        let task_ctx = new_task_ctx(batch_size);
        let left = build_left_table();
        let right = build_right_table();

        let filter = prepare_join_filter();
        let (columns, batches, metrics) = multi_partitioned_join_collect(
            left,
            right,
            &JoinType::RightMark,
            Some(filter),
            task_ctx,
        )
        .await?;
        assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);

        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+-----+-------+
        | a2 | b2 | c2 | mark |
        +----+----+-----+-------+
        | 10 | 10 | 100 | false |
        | 12 | 10 | 40 | false |
        | 2 | 2 | 80 | true |
        +----+----+-----+-------+
        "#));

        assert_join_metrics!(metrics, 3);

        Ok(())
    }

    /// With a very small memory limit, every join type must fail with a
    /// "Resources exhausted" error attributed to `NestedLoopJoinLoad[0]`.
    #[tokio::test]
    async fn test_overallocation() -> Result<()> {
        let left = build_table(
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            None,
            Vec::new(),
        );
        let right = build_table(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
            None,
            Vec::new(),
        );
        let filter = prepare_join_filter();

        let join_types = vec![
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::Full,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightSemi,
            JoinType::RightAnti,
            JoinType::RightMark,
        ];

        for join_type in join_types {
            // A fresh runtime per join type so the limit is not shared.
            let runtime = RuntimeEnvBuilder::new()
                .with_memory_limit(100, 1.0)
                .build_arc()?;
            let task_ctx = TaskContext::default().with_runtime(runtime);
            let task_ctx = Arc::new(task_ctx);

            let err = multi_partitioned_join_collect(
                Arc::clone(&left),
                Arc::clone(&right),
                &join_type,
                Some(filter.clone()),
                task_ctx,
            )
            .await
            .unwrap_err();

            assert_contains!(
                err.to_string(),
                "Resources exhausted: Additional allocation failed for NestedLoopJoinLoad[0] with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]"
            );
        }

        Ok(())
    }

    /// Returns the column names of `schema`, in order.
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }
}