1use std::{any::Any, sync::Arc, task::Poll};
22
23use super::utils::{
24 adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter,
25 BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
26 StatefulStreamResult,
27};
28use crate::execution_plan::{boundedness_from_children, EmissionType};
29use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
30use crate::projection::{
31 join_allows_pushdown, join_table_borders, new_join_children,
32 physical_to_column_exprs, ProjectionExec,
33};
34use crate::{
35 handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution,
36 ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
37 SendableRecordBatchStream, Statistics,
38};
39
40use arrow::array::{RecordBatch, RecordBatchOptions};
41use arrow::compute::concat_batches;
42use arrow::datatypes::{Fields, Schema, SchemaRef};
43use datafusion_common::stats::Precision;
44use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
45use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
46use datafusion_execution::TaskContext;
47use datafusion_physical_expr::equivalence::join_equivalence_properties;
48
49use async_trait::async_trait;
50use futures::{ready, Stream, StreamExt, TryStreamExt};
51
/// Left (build) side data of a cross join, produced by `load_left_input`.
#[derive(Debug)]
struct JoinLeftData {
    /// All batches read from the left input, concatenated into a single batch.
    merged_batch: RecordBatch,
    /// Memory reservation grown by the size of each left batch while it was
    /// loaded. It is never read; it is stored here so the reservation lives as
    /// long as the buffered left data does.
    _reservation: MemoryReservation,
}
61
/// Cross join (Cartesian product) execution plan.
///
/// The left input must have exactly one output partition. It is loaded fully
/// into memory and combined with every batch of each right input partition.
/// There is one output partition per right input partition.
///
/// # Clone / shared state
///
/// The plan holds a `OnceAsync` (`left_fut`) that coordinates loading the left
/// side across all output streams, so it intentionally does not derive `Clone`.
#[allow(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
pub struct CrossJoinExec {
    /// Left (build) input; buffered in memory in its entirety.
    pub left: Arc<dyn ExecutionPlan>,
    /// Right (probe) input; streamed partition by partition.
    pub right: Arc<dyn ExecutionPlan>,
    /// Output schema: all left fields followed by all right fields.
    schema: SchemaRef,
    /// Future that loads the left side once; every output stream obtains its
    /// handle to the left data through it.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution metrics for this plan.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties computed in `compute_properties`.
    cache: PlanProperties,
}
97
98impl CrossJoinExec {
99 pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
101 let (all_columns, metadata) = {
103 let left_schema = left.schema();
104 let right_schema = right.schema();
105 let left_fields = left_schema.fields().iter();
106 let right_fields = right_schema.fields().iter();
107
108 let mut metadata = left_schema.metadata().clone();
109 metadata.extend(right_schema.metadata().clone());
110
111 (
112 left_fields.chain(right_fields).cloned().collect::<Fields>(),
113 metadata,
114 )
115 };
116
117 let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
118 let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap();
119
120 CrossJoinExec {
121 left,
122 right,
123 schema,
124 left_fut: Default::default(),
125 metrics: ExecutionPlanMetricsSet::default(),
126 cache,
127 }
128 }
129
130 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
132 &self.left
133 }
134
135 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
137 &self.right
138 }
139
140 fn compute_properties(
142 left: &Arc<dyn ExecutionPlan>,
143 right: &Arc<dyn ExecutionPlan>,
144 schema: SchemaRef,
145 ) -> Result<PlanProperties> {
146 let eq_properties = join_equivalence_properties(
150 left.equivalence_properties().clone(),
151 right.equivalence_properties().clone(),
152 &JoinType::Full,
153 schema,
154 &[false, false],
155 None,
156 &[],
157 )?;
158
159 let output_partitioning = adjust_right_output_partitioning(
163 right.output_partitioning(),
164 left.schema().fields.len(),
165 )?;
166
167 Ok(PlanProperties::new(
168 eq_properties,
169 output_partitioning,
170 EmissionType::Final,
171 boundedness_from_children([left, right]),
172 ))
173 }
174
175 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
185 let new_join =
186 CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
187 reorder_output_after_swap(
188 Arc::new(new_join),
189 &self.left.schema(),
190 &self.right.schema(),
191 )
192 }
193}
194
195async fn load_left_input(
197 stream: SendableRecordBatchStream,
198 metrics: BuildProbeJoinMetrics,
199 reservation: MemoryReservation,
200) -> Result<JoinLeftData> {
201 let left_schema = stream.schema();
202
203 let (batches, _metrics, reservation) = stream
205 .try_fold(
206 (Vec::new(), metrics, reservation),
207 |(mut batches, metrics, mut reservation), batch| async {
208 let batch_size = batch.get_array_memory_size();
209 reservation.try_grow(batch_size)?;
211 metrics.build_mem_used.add(batch_size);
213 metrics.build_input_batches.add(1);
214 metrics.build_input_rows.add(batch.num_rows());
215 batches.push(batch);
217 Ok((batches, metrics, reservation))
218 },
219 )
220 .await?;
221
222 let merged_batch = concat_batches(&left_schema, &batches)?;
223
224 Ok(JoinLeftData {
225 merged_batch,
226 _reservation: reservation,
227 })
228}
229
230impl DisplayAs for CrossJoinExec {
231 fn fmt_as(
232 &self,
233 t: DisplayFormatType,
234 f: &mut std::fmt::Formatter,
235 ) -> std::fmt::Result {
236 match t {
237 DisplayFormatType::Default | DisplayFormatType::Verbose => {
238 write!(f, "CrossJoinExec")
239 }
240 DisplayFormatType::TreeRender => {
241 Ok(())
243 }
244 }
245 }
246}
247
248impl ExecutionPlan for CrossJoinExec {
249 fn name(&self) -> &'static str {
250 "CrossJoinExec"
251 }
252
253 fn as_any(&self) -> &dyn Any {
254 self
255 }
256
257 fn properties(&self) -> &PlanProperties {
258 &self.cache
259 }
260
261 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
262 vec![&self.left, &self.right]
263 }
264
265 fn metrics(&self) -> Option<MetricsSet> {
266 Some(self.metrics.clone_inner())
267 }
268
269 fn with_new_children(
270 self: Arc<Self>,
271 children: Vec<Arc<dyn ExecutionPlan>>,
272 ) -> Result<Arc<dyn ExecutionPlan>> {
273 Ok(Arc::new(CrossJoinExec::new(
274 Arc::clone(&children[0]),
275 Arc::clone(&children[1]),
276 )))
277 }
278
279 fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
280 let new_exec = CrossJoinExec {
281 left: Arc::clone(&self.left),
282 right: Arc::clone(&self.right),
283 schema: Arc::clone(&self.schema),
284 left_fut: Default::default(), metrics: ExecutionPlanMetricsSet::default(),
286 cache: self.cache.clone(),
287 };
288 Ok(Arc::new(new_exec))
289 }
290
291 fn required_input_distribution(&self) -> Vec<Distribution> {
292 vec![
293 Distribution::SinglePartition,
294 Distribution::UnspecifiedDistribution,
295 ]
296 }
297
298 fn execute(
299 &self,
300 partition: usize,
301 context: Arc<TaskContext>,
302 ) -> Result<SendableRecordBatchStream> {
303 if self.left.output_partitioning().partition_count() != 1 {
304 return internal_err!(
305 "Invalid CrossJoinExec, the output partition count of the left child must be 1,\
306 consider using CoalescePartitionsExec or the EnforceDistribution rule"
307 );
308 }
309
310 let stream = self.right.execute(partition, Arc::clone(&context))?;
311
312 let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
313
314 let reservation =
316 MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
317
318 let batch_size = context.session_config().batch_size();
319 let enforce_batch_size_in_joins =
320 context.session_config().enforce_batch_size_in_joins();
321
322 let left_fut = self.left_fut.try_once(|| {
323 let left_stream = self.left.execute(0, context)?;
324
325 Ok(load_left_input(
326 left_stream,
327 join_metrics.clone(),
328 reservation,
329 ))
330 })?;
331
332 if enforce_batch_size_in_joins {
333 Ok(Box::pin(CrossJoinStream {
334 schema: Arc::clone(&self.schema),
335 left_fut,
336 right: stream,
337 left_index: 0,
338 join_metrics,
339 state: CrossJoinStreamState::WaitBuildSide,
340 left_data: RecordBatch::new_empty(self.left().schema()),
341 batch_transformer: BatchSplitter::new(batch_size),
342 }))
343 } else {
344 Ok(Box::pin(CrossJoinStream {
345 schema: Arc::clone(&self.schema),
346 left_fut,
347 right: stream,
348 left_index: 0,
349 join_metrics,
350 state: CrossJoinStreamState::WaitBuildSide,
351 left_data: RecordBatch::new_empty(self.left().schema()),
352 batch_transformer: NoopBatchTransformer::new(),
353 }))
354 }
355 }
356
357 fn statistics(&self) -> Result<Statistics> {
358 self.partition_statistics(None)
359 }
360
361 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
362 let left_stats = self.left.partition_statistics(None)?;
364 let right_stats = self.right.partition_statistics(partition)?;
365
366 Ok(stats_cartesian_product(left_stats, right_stats))
367 }
368
369 fn try_swapping_with_projection(
373 &self,
374 projection: &ProjectionExec,
375 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
376 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
378 else {
379 return Ok(None);
380 };
381
382 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
383 self.left().schema().fields().len(),
384 &projection_as_columns,
385 );
386
387 if !join_allows_pushdown(
388 &projection_as_columns,
389 &self.schema(),
390 far_right_left_col_ind,
391 far_left_right_col_ind,
392 ) {
393 return Ok(None);
394 }
395
396 let (new_left, new_right) = new_join_children(
397 &projection_as_columns,
398 far_right_left_col_ind,
399 far_left_right_col_ind,
400 self.left(),
401 self.right(),
402 )?;
403
404 Ok(Some(Arc::new(CrossJoinExec::new(
405 Arc::new(new_left),
406 Arc::new(new_right),
407 ))))
408 }
409}
410
411fn stats_cartesian_product(
413 left_stats: Statistics,
414 right_stats: Statistics,
415) -> Statistics {
416 let left_row_count = left_stats.num_rows;
417 let right_row_count = right_stats.num_rows;
418
419 let num_rows = left_row_count.multiply(&right_row_count);
421 let total_byte_size = left_stats
423 .total_byte_size
424 .multiply(&right_stats.total_byte_size)
425 .multiply(&Precision::Exact(2));
426
427 let left_col_stats = left_stats.column_statistics;
428 let right_col_stats = right_stats.column_statistics;
429
430 let cross_join_stats = left_col_stats
433 .into_iter()
434 .map(|s| ColumnStatistics {
435 null_count: s.null_count.multiply(&right_row_count),
436 distinct_count: s.distinct_count,
437 min_value: s.min_value,
438 max_value: s.max_value,
439 sum_value: s
440 .sum_value
441 .get_value()
442 .and_then(|v| {
444 Precision::<ScalarValue>::from(right_row_count)
445 .cast_to(&v.data_type())
446 .ok()
447 })
448 .map(|row_count| s.sum_value.multiply(&row_count))
449 .unwrap_or(Precision::Absent),
450 })
451 .chain(right_col_stats.into_iter().map(|s| {
452 ColumnStatistics {
453 null_count: s.null_count.multiply(&left_row_count),
454 distinct_count: s.distinct_count,
455 min_value: s.min_value,
456 max_value: s.max_value,
457 sum_value: s
458 .sum_value
459 .get_value()
460 .and_then(|v| {
462 Precision::<ScalarValue>::from(left_row_count)
463 .cast_to(&v.data_type())
464 .ok()
465 })
466 .map(|row_count| s.sum_value.multiply(&row_count))
467 .unwrap_or(Precision::Absent),
468 }
469 }))
470 .collect();
471
472 Statistics {
473 num_rows,
474 total_byte_size,
475 column_statistics: cross_join_stats,
476 }
477}
478
/// Stream producing the cross join of the buffered left data with one
/// partition of the right input.
///
/// `T` post-processes each batch built for a single left row. In `execute`
/// it is either a `BatchSplitter` sized to the session batch size or a
/// `NoopBatchTransformer`.
struct CrossJoinStream<T> {
    /// Output schema: left columns followed by right columns.
    schema: Arc<Schema>,
    /// Future resolving to the loaded left-side data.
    left_fut: OnceFut<JoinLeftData>,
    /// Right-side input stream for this partition.
    right: SendableRecordBatchStream,
    /// Index of the left row currently combined with the held right batch.
    left_index: usize,
    /// Join execution metrics.
    join_metrics: BuildProbeJoinMetrics,
    /// Current state of the stream's state machine.
    state: CrossJoinStreamState,
    /// Merged left-side batch, taken from `left_fut` once it resolves.
    left_data: RecordBatch,
    /// Transformer applied to each batch built for a single left row.
    batch_transformer: T,
}
498
499impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
500 fn schema(&self) -> SchemaRef {
501 Arc::clone(&self.schema)
502 }
503}
504
/// States of the `CrossJoinStream` state machine.
enum CrossJoinStreamState {
    /// Waiting for the left (build) side to be loaded.
    WaitBuildSide,
    /// Ready to read the next batch from the right (probe) input.
    FetchProbeBatch,
    /// Producing output for the held right batch, one left row at a time.
    BuildBatches(RecordBatch),
}
512
513impl CrossJoinStreamState {
514 fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
517 match self {
518 CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
519 _ => internal_err!("Expected RecordBatch in BuildBatches state"),
520 }
521 }
522}
523
524fn build_batch(
525 left_index: usize,
526 batch: &RecordBatch,
527 left_data: &RecordBatch,
528 schema: &Schema,
529) -> Result<RecordBatch> {
530 let arrays = left_data
532 .columns()
533 .iter()
534 .map(|arr| {
535 let scalar = ScalarValue::try_from_array(arr, left_index)?;
536 scalar.to_array_of_size(batch.num_rows())
537 })
538 .collect::<Result<Vec<_>>>()?;
539
540 RecordBatch::try_new_with_options(
541 Arc::new(schema.clone()),
542 arrays
543 .iter()
544 .chain(batch.columns().iter())
545 .cloned()
546 .collect(),
547 &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
548 )
549 .map_err(Into::into)
550}
551
552#[async_trait]
553impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
554 type Item = Result<RecordBatch>;
555
556 fn poll_next(
557 mut self: std::pin::Pin<&mut Self>,
558 cx: &mut std::task::Context<'_>,
559 ) -> Poll<Option<Self::Item>> {
560 self.poll_next_impl(cx)
561 }
562}
563
impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Drives the state machine until it produces a batch, finishes, errors,
    /// or has to wait on an input.
    ///
    /// NOTE(review): `handle_state!` is defined elsewhere. It presumably
    /// `continue`s this loop on `StatefulStreamResult::Continue` and returns
    /// the wrapped value on `StatefulStreamResult::Ready`; confirm.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    // Only polls from the batch-building state are passed to
                    // the baseline metrics via `record_poll`.
                    let poll = handle_state!(self.build_batches());
                    self.join_metrics.baseline.record_poll(poll)
                }
            };
        }
    }

    /// Waits for the left side to be loaded.
    ///
    /// If the merged left batch has no rows, the join output is empty and the
    /// stream finishes (`Ready(None)`). Otherwise the batch is stored in
    /// `left_data` and the state moves to `FetchProbeBatch`. The timer
    /// started here is attached to `build_time`.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        let build_timer = self.join_metrics.build_time.timer();
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Fetches the next batch from the right input.
    ///
    /// Resets `left_index` to 0. The end of the right input ends the stream
    /// (`Ready(None)`), and input errors are propagated. Otherwise the input
    /// metrics are updated and the state moves to `BuildBatches` holding the
    /// new right batch.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Produces output for the held right batch, one left row at a time.
    ///
    /// - If the transformer has no pending output, builds the batch for the
    ///   current left row (timed in `join_time`) and hands it to the
    ///   transformer.
    /// - Otherwise emits the transformer's next batch, advancing `left_index`
    ///   once the transformer reports the last piece for this left row.
    ///
    /// After all left rows are processed, the state returns to
    /// `FetchProbeBatch`.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            match self.batch_transformer.next() {
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                Some((batch, last)) => {
                    if last {
                        self.left_index += 1;
                    }

                    self.join_metrics.output_batches.add(1);
                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}
663
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common;
    use crate::test::{assert_join_metrics, build_table_scan_i32};

    use datafusion_common::{assert_contains, test_util::batches_to_sort_string};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use insta::assert_snapshot;

    /// Builds a `CrossJoinExec` over `left` and `right` and executes output
    /// partition 0.
    ///
    /// Returns the output column names, the collected batches and the plan's
    /// metrics.
    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
        let join = CrossJoinExec::new(left, right);
        let columns_header = columns(&join.schema());

        let stream = join.execute(0, context)?;
        let batches = common::collect(stream).await?;
        let metrics = join.metrics().unwrap();

        Ok((columns_header, batches, metrics))
    }

    /// With exact statistics on both sides:
    /// - row counts multiply;
    /// - byte size is `2 * left * right`;
    /// - each column's null count and sum are scaled by the other side's row
    ///   count.
    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// When the right row count is unknown:
    /// - the output row count and byte size are unknown;
    /// - left columns lose their null count and sum;
    /// - right columns are still scaled by the known left row count.
    #[tokio::test]
    async fn test_stats_cartesian_product_with_unknown_size() {
        let left_row_count = 11;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    // Unknown: the right row count is absent.
                    sum_value: Precision::Absent,
                    // Unknown: the right row count is absent.
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    // Unknown: the right row count is absent.
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// A 3-row left input crossed with a 2-row right input yields all 6
    /// combinations, with left columns first.
    #[tokio::test]
    async fn test_join() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 6]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?;

        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);

        assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +----+----+----+----+----+----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+----+----+----+----+
        | 1  | 4  | 7  | 10 | 12 | 14 |
        | 1  | 4  | 7  | 11 | 13 | 15 |
        | 2  | 5  | 8  | 10 | 12 | 14 |
        | 2  | 5  | 8  | 11 | 13 | 15 |
        | 3  | 6  | 9  | 10 | 12 | 14 |
        | 3  | 6  | 9  | 11 | 13 | 15 |
        +----+----+----+----+----+----+
        "#);

        assert_join_metrics!(metrics, 6);

        Ok(())
    }

    /// With a 100-byte memory limit, buffering the 10-row left input fails
    /// with a resources-exhausted error attributed to `CrossJoinExec`.
    #[tokio::test]
    async fn test_overallocation() -> Result<()> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(100, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let err = join_collect(left, right, task_ctx).await.unwrap_err();

        assert_contains!(
            err.to_string(),
            "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n  CrossJoinExec"
        );

        Ok(())
    }

    /// Returns the names of all fields in `schema`, in order.
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }
}