1use std::{sync::Arc, task::Poll};
22
23use super::utils::{
24 BatchSplitter, BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer,
25 OnceAsync, OnceFut, StatefulStreamResult, adjust_right_output_partitioning,
26 reorder_output_after_swap,
27};
28use crate::execution_plan::{EmissionType, boundedness_from_children};
29use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
30use crate::projection::{
31 ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
32 physical_to_column_exprs,
33};
34use crate::stream::EmptyRecordBatchStream;
35use crate::{
36 ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
37 ExecutionPlanProperties, PlanProperties, RecordBatchStream,
38 SendableRecordBatchStream, Statistics, check_if_same_properties, handle_state,
39};
40
41use arrow::array::{RecordBatch, RecordBatchOptions};
42use arrow::compute::concat_batches;
43use arrow::datatypes::{Fields, Schema, SchemaRef};
44use datafusion_common::stats::Precision;
45use datafusion_common::{
46 JoinType, Result, ScalarValue, assert_eq_or_internal_err, internal_err,
47};
48use datafusion_execution::TaskContext;
49use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
50use datafusion_physical_expr::equivalence::join_equivalence_properties;
51
52use async_trait::async_trait;
53use futures::{Stream, StreamExt, TryStreamExt, ready};
54
/// Buffered build-side (left) input of a cross join.
///
/// Produced once by `load_left_input` and shared by the output streams
/// through `CrossJoinExec::left_fut`.
#[derive(Debug)]
struct JoinLeftData {
    /// Every batch of the left input concatenated into a single batch.
    merged_batch: RecordBatch,
    /// Memory reservation grown by the size of each buffered left batch.
    /// Kept alongside `merged_batch` so the reservation outlives the data;
    /// presumably released when dropped (see `MemoryReservation`) — confirm.
    _reservation: MemoryReservation,
}
64
/// Cross join (Cartesian product) execution plan.
///
/// Every row of `left` is paired with every row of `right`. The output schema
/// is the left fields followed by the right fields.
///
/// The left (build) input must have exactly one partition. `execute` loads it
/// into memory as a [`JoinLeftData`], coordinated through the [`OnceAsync`]
/// stored in `left_fut` so the load is shared by all output partitions. Each
/// output partition then streams the matching partition of the right (probe)
/// input and combines it with the buffered left rows.
///
/// Because of this shared `OnceAsync` state the type does not derive
/// [`Clone`]; `reset_state` builds a copy with a fresh `left_fut`.
#[expect(rustdoc::private_intra_doc_links)]
#[derive(Debug)]
pub struct CrossJoinExec {
    /// Left (build) input; fully buffered in memory during execution.
    pub left: Arc<dyn ExecutionPlan>,
    /// Right (probe) input; streamed one partition per output partition.
    pub right: Arc<dyn ExecutionPlan>,
    /// Output schema: left fields then right fields, with merged metadata.
    schema: SchemaRef,
    /// Shared load of the left input, awaited by every output stream.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution metrics for this plan.
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, partitioning, emission type,
    /// boundedness) computed in `compute_properties`.
    cache: Arc<PlanProperties>,
}
100
101impl CrossJoinExec {
102 pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
104 let (all_columns, metadata) = {
106 let left_schema = left.schema();
107 let right_schema = right.schema();
108 let left_fields = left_schema.fields().iter();
109 let right_fields = right_schema.fields().iter();
110
111 let mut metadata = left_schema.metadata().clone();
112 metadata.extend(right_schema.metadata().clone());
113
114 (
115 left_fields.chain(right_fields).cloned().collect::<Fields>(),
116 metadata,
117 )
118 };
119
120 let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
121 let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap();
122
123 CrossJoinExec {
124 left,
125 right,
126 schema,
127 left_fut: Default::default(),
128 metrics: ExecutionPlanMetricsSet::default(),
129 cache: Arc::new(cache),
130 }
131 }
132
133 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
135 &self.left
136 }
137
138 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
140 &self.right
141 }
142
143 fn compute_properties(
145 left: &Arc<dyn ExecutionPlan>,
146 right: &Arc<dyn ExecutionPlan>,
147 schema: SchemaRef,
148 ) -> Result<PlanProperties> {
149 let eq_properties = join_equivalence_properties(
153 left.equivalence_properties().clone(),
154 right.equivalence_properties().clone(),
155 &JoinType::Full,
156 schema,
157 &[false, false],
158 None,
159 &[],
160 )?;
161
162 let output_partitioning = adjust_right_output_partitioning(
166 right.output_partitioning(),
167 left.schema().fields.len(),
168 )?;
169
170 Ok(PlanProperties::new(
171 eq_properties,
172 output_partitioning,
173 EmissionType::Final,
174 boundedness_from_children([left, right]),
175 ))
176 }
177
178 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
188 let new_join =
189 CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
190 reorder_output_after_swap(
191 Arc::new(new_join),
192 &self.left.schema(),
193 &self.right.schema(),
194 )
195 }
196
197 fn with_new_children_and_same_properties(
198 &self,
199 mut children: Vec<Arc<dyn ExecutionPlan>>,
200 ) -> Self {
201 let left = children.swap_remove(0);
202 let right = children.swap_remove(0);
203
204 Self {
205 left,
206 right,
207 metrics: ExecutionPlanMetricsSet::new(),
208 left_fut: Default::default(),
209 cache: Arc::clone(&self.cache),
210 schema: Arc::clone(&self.schema),
211 }
212 }
213}
214
215async fn load_left_input(
217 stream: SendableRecordBatchStream,
218 metrics: BuildProbeJoinMetrics,
219 reservation: MemoryReservation,
220) -> Result<JoinLeftData> {
221 let left_schema = stream.schema();
222
223 let (batches, _metrics, reservation) = stream
225 .try_fold(
226 (Vec::new(), metrics, reservation),
227 |(mut batches, metrics, reservation), batch| async {
228 let batch_size = batch.get_array_memory_size();
229 reservation.try_grow(batch_size)?;
231 metrics.build_mem_used.add(batch_size);
233 metrics.build_input_batches.add(1);
234 metrics.build_input_rows.add(batch.num_rows());
235 batches.push(batch);
237 Ok((batches, metrics, reservation))
238 },
239 )
240 .await?;
241
242 let merged_batch = concat_batches(&left_schema, &batches)?;
243
244 Ok(JoinLeftData {
245 merged_batch,
246 _reservation: reservation,
247 })
248}
249
250impl DisplayAs for CrossJoinExec {
251 fn fmt_as(
252 &self,
253 t: DisplayFormatType,
254 f: &mut std::fmt::Formatter,
255 ) -> std::fmt::Result {
256 match t {
257 DisplayFormatType::Default | DisplayFormatType::Verbose => {
258 write!(f, "CrossJoinExec")
259 }
260 DisplayFormatType::TreeRender => {
261 Ok(())
263 }
264 }
265 }
266}
267
268impl ExecutionPlan for CrossJoinExec {
269 fn name(&self) -> &'static str {
270 "CrossJoinExec"
271 }
272
273 fn properties(&self) -> &Arc<PlanProperties> {
274 &self.cache
275 }
276
277 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
278 vec![&self.left, &self.right]
279 }
280
281 fn metrics(&self) -> Option<MetricsSet> {
282 Some(self.metrics.clone_inner())
283 }
284
285 fn with_new_children(
286 self: Arc<Self>,
287 children: Vec<Arc<dyn ExecutionPlan>>,
288 ) -> Result<Arc<dyn ExecutionPlan>> {
289 check_if_same_properties!(self, children);
290 Ok(Arc::new(CrossJoinExec::new(
291 Arc::clone(&children[0]),
292 Arc::clone(&children[1]),
293 )))
294 }
295
296 fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
297 let new_exec = CrossJoinExec {
298 left: Arc::clone(&self.left),
299 right: Arc::clone(&self.right),
300 schema: Arc::clone(&self.schema),
301 left_fut: Default::default(), metrics: ExecutionPlanMetricsSet::default(),
303 cache: Arc::clone(&self.cache),
304 };
305 Ok(Arc::new(new_exec))
306 }
307
308 fn required_input_distribution(&self) -> Vec<Distribution> {
309 vec![
310 Distribution::SinglePartition,
311 Distribution::UnspecifiedDistribution,
312 ]
313 }
314
315 fn execute(
316 &self,
317 partition: usize,
318 context: Arc<TaskContext>,
319 ) -> Result<SendableRecordBatchStream> {
320 assert_eq_or_internal_err!(
321 self.left.output_partitioning().partition_count(),
322 1,
323 "Invalid CrossJoinExec, the output partition count of the left child must be 1,\
324 consider using CoalescePartitionsExec or the EnforceDistribution rule"
325 );
326
327 let stream = self.right.execute(partition, Arc::clone(&context))?;
328
329 let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
330
331 let reservation =
333 MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
334
335 let batch_size = context.session_config().batch_size();
336 let enforce_batch_size_in_joins =
337 context.session_config().enforce_batch_size_in_joins();
338
339 let left_fut = self.left_fut.try_once(|| {
340 let left_stream = self.left.execute(0, context)?;
341
342 Ok(load_left_input(
343 left_stream,
344 join_metrics.clone(),
345 reservation,
346 ))
347 })?;
348
349 if enforce_batch_size_in_joins {
350 Ok(Box::pin(CrossJoinStream {
351 schema: Arc::clone(&self.schema),
352 left_fut,
353 right: stream,
354 left_index: 0,
355 join_metrics,
356 state: CrossJoinStreamState::WaitBuildSide,
357 left_data: RecordBatch::new_empty(self.left().schema()),
358 batch_transformer: BatchSplitter::new(batch_size),
359 }))
360 } else {
361 Ok(Box::pin(CrossJoinStream {
362 schema: Arc::clone(&self.schema),
363 left_fut,
364 right: stream,
365 left_index: 0,
366 join_metrics,
367 state: CrossJoinStreamState::WaitBuildSide,
368 left_data: RecordBatch::new_empty(self.left().schema()),
369 batch_transformer: NoopBatchTransformer::new(),
370 }))
371 }
372 }
373
374 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
375 let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(None)?);
377 let right_stats =
378 Arc::unwrap_or_clone(self.right.partition_statistics(partition)?);
379
380 Ok(Arc::new(stats_cartesian_product(left_stats, right_stats)))
381 }
382
383 fn try_swapping_with_projection(
387 &self,
388 projection: &ProjectionExec,
389 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
390 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
392 else {
393 return Ok(None);
394 };
395
396 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
397 self.left().schema().fields().len(),
398 &projection_as_columns,
399 );
400
401 if !join_allows_pushdown(
402 &projection_as_columns,
403 &self.schema(),
404 far_right_left_col_ind,
405 far_left_right_col_ind,
406 ) {
407 return Ok(None);
408 }
409
410 let (new_left, new_right) = new_join_children(
411 &projection_as_columns,
412 far_right_left_col_ind,
413 far_left_right_col_ind,
414 self.left(),
415 self.right(),
416 )?;
417
418 Ok(Some(Arc::new(CrossJoinExec::new(
419 Arc::new(new_left),
420 Arc::new(new_right),
421 ))))
422 }
423}
424
425fn stats_cartesian_product(
427 left_stats: Statistics,
428 right_stats: Statistics,
429) -> Statistics {
430 let left_row_count = left_stats.num_rows;
431 let right_row_count = right_stats.num_rows;
432
433 let num_rows = left_row_count.multiply(&right_row_count);
435 let total_byte_size = left_stats
437 .total_byte_size
438 .multiply(&right_stats.total_byte_size)
439 .multiply(&Precision::Exact(2));
440
441 let left_col_stats = left_stats.column_statistics;
442 let right_col_stats = right_stats.column_statistics;
443
444 let cross_join_stats = left_col_stats
447 .into_iter()
448 .map(|s| {
449 let widened_sum = s.sum_value.cast_to_sum_type();
450 ColumnStatistics {
451 null_count: s.null_count.multiply(&right_row_count),
452 distinct_count: s.distinct_count,
453 min_value: s.min_value,
454 max_value: s.max_value,
455 sum_value: widened_sum
456 .get_value()
457 .and_then(|v| {
459 Precision::<ScalarValue>::from(right_row_count)
460 .cast_to(&v.data_type())
461 .ok()
462 })
463 .map(|row_count| widened_sum.multiply(&row_count))
464 .unwrap_or(Precision::Absent),
465 byte_size: Precision::Absent,
466 }
467 })
468 .chain(right_col_stats.into_iter().map(|s| {
469 let widened_sum = s.sum_value.cast_to_sum_type();
470 ColumnStatistics {
471 null_count: s.null_count.multiply(&left_row_count),
472 distinct_count: s.distinct_count,
473 min_value: s.min_value,
474 max_value: s.max_value,
475 sum_value: widened_sum
476 .get_value()
477 .and_then(|v| {
479 Precision::<ScalarValue>::from(left_row_count)
480 .cast_to(&v.data_type())
481 .ok()
482 })
483 .map(|row_count| widened_sum.multiply(&row_count))
484 .unwrap_or(Precision::Absent),
485 byte_size: Precision::Absent,
486 }
487 }))
488 .collect();
489
490 Statistics {
491 num_rows,
492 total_byte_size,
493 column_statistics: cross_join_stats,
494 }
495}
496
/// Stream producing the cross join output for one partition of the right
/// input.
///
/// `T` post-processes each joined batch before it is emitted. `execute` uses
/// a `BatchSplitter` when batch-size enforcement is enabled and a
/// `NoopBatchTransformer` otherwise.
struct CrossJoinStream<T> {
    /// Output schema (left fields followed by right fields).
    schema: Arc<Schema>,
    /// Shared future resolving to the buffered left input.
    left_fut: OnceFut<JoinLeftData>,
    /// Right (probe) input for this partition.
    right: SendableRecordBatchStream,
    /// Index of the left row currently being paired with the right batch.
    left_index: usize,
    /// Build/probe join metrics for this partition.
    join_metrics: BuildProbeJoinMetrics,
    /// Current state of the stream's state machine.
    state: CrossJoinStreamState,
    /// Buffered left input; starts empty and is set once `left_fut`
    /// resolves to a non-empty batch.
    left_data: RecordBatch,
    /// Transformer applied to each joined batch (e.g. splitting by size).
    batch_transformer: T,
}
516
517impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
518 fn schema(&self) -> SchemaRef {
519 Arc::clone(&self.schema)
520 }
521}
522
/// States of [`CrossJoinStream`].
enum CrossJoinStreamState {
    /// Waiting for the left (build) input to finish loading.
    WaitBuildSide,
    /// Ready to pull the next batch from the right (probe) input.
    FetchProbeBatch,
    /// Emitting joined output for this right batch, one left row at a time.
    BuildBatches(RecordBatch),
}
530
531impl CrossJoinStreamState {
532 fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
535 match self {
536 CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
537 _ => internal_err!("Expected RecordBatch in BuildBatches state"),
538 }
539 }
540}
541
542fn build_batch(
543 left_index: usize,
544 batch: &RecordBatch,
545 left_data: &RecordBatch,
546 schema: &Schema,
547) -> Result<RecordBatch> {
548 let arrays = left_data
550 .columns()
551 .iter()
552 .map(|arr| {
553 let scalar = ScalarValue::try_from_array(arr, left_index)?;
554 scalar.to_array_of_size(batch.num_rows())
555 })
556 .collect::<Result<Vec<_>>>()?;
557
558 RecordBatch::try_new_with_options(
559 Arc::new(schema.clone()),
560 arrays
561 .iter()
562 .chain(batch.columns().iter())
563 .cloned()
564 .collect(),
565 &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
566 )
567 .map_err(Into::into)
568}
569
570#[async_trait]
571impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
572 type Item = Result<RecordBatch>;
573
574 fn poll_next(
575 mut self: std::pin::Pin<&mut Self>,
576 cx: &mut std::task::Context<'_>,
577 ) -> Poll<Option<Self::Item>> {
578 self.poll_next_impl(cx)
579 }
580}
581
impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Drives the stream's state machine until it can return a result.
    ///
    /// `WaitBuildSide` waits for the buffered left input, `FetchProbeBatch`
    /// pulls the next right batch, and `BuildBatches` emits joined output for
    /// the current right batch. `handle_state!` (defined elsewhere) is
    /// expected to `continue` this loop when a step reports
    /// `StatefulStreamResult::Continue` and to yield a `Poll` otherwise —
    /// NOTE(review): confirm against the macro definition.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    // Only this state produces output rows, so only its
                    // results are recorded in the baseline metrics.
                    let poll = handle_state!(self.build_batches());
                    self.join_metrics.baseline.record_poll(poll)
                }
            };
        }
    }

    /// Waits for the shared left-input future and stores its merged batch.
    ///
    /// Returns `Ready(None)` when the left input is empty. The cross product
    /// is then empty, and the right input is never read. Otherwise it moves
    /// to `FetchProbeBatch`.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        // NOTE(review): if `ready!` returns `Pending` below, this timer guard
        // is dropped without `done()`; whether elapsed time is still recorded
        // depends on the guard's `Drop` impl — confirm.
        let build_timer = self.join_metrics.build_time.timer();
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Pulls the next batch from the right input and switches to
    /// `BuildBatches` for it. Returns `Ready(None)` once the right input is
    /// exhausted.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        // Each new right batch is paired starting from the first left row.
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            None => {
                // Replace the exhausted input with an empty stream of the same
                // schema — presumably to release the finished input early.
                let right_schema = self.right.schema();
                self.right = Box::pin(EmptyRecordBatchStream::new(right_schema));
                return Poll::Ready(Ok(StatefulStreamResult::Ready(None)));
            }
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Emits the joined output for the current right batch.
    ///
    /// For each left row, one joined batch is built with `build_batch` and
    /// handed to `batch_transformer`, which may emit it in several pieces.
    /// `left_index` advances only after the transformer reports the last
    /// piece. Once every left row is done, the state returns to
    /// `FetchProbeBatch`.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            match self.batch_transformer.next() {
                // Transformer is drained: build the batch for the current left row.
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                Some((batch, last)) => {
                    if last {
                        self.left_index += 1;
                    }

                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}
685
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common;
    use crate::test::{assert_join_metrics, build_table_scan_i32};

    use datafusion_common::{assert_contains, test_util::batches_to_sort_string};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use insta::assert_snapshot;

    /// Executes partition 0 of a `CrossJoinExec` over `left` and `right`.
    ///
    /// Returns the output column names, the collected batches, and the plan's
    /// metrics.
    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
        let join = CrossJoinExec::new(left, right);
        let columns_header = columns(&join.schema());

        let stream = join.execute(0, context)?;
        let batches = common::collect(stream).await?;
        let metrics = join.metrics().unwrap();

        Ok((columns_header, batches, metrics))
    }

    /// With exact inputs, row counts multiply and byte size is
    /// `2 * left_bytes * right_bytes`. Each side's null counts and sums are
    /// scaled by the other side's row count.
    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// When the right row count is unknown, the totals and the left columns'
    /// null counts and sums become `Absent`. The right columns are still
    /// scaled by the known left row count.
    #[tokio::test]
    async fn test_stats_cartesian_product_with_unknown_size() {
        let left_row_count = 11;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// `UInt32` sums are widened to `UInt64` before being multiplied by the
    /// other side's row count.
    #[tokio::test]
    async fn test_stats_cartesian_product_unsigned_sum_widens_to_u64() {
        let left_row_count = 2;
        let right_row_count = 3;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(10),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(2),
                max_value: Precision::Exact(ScalarValue::UInt32(Some(10))),
                min_value: Precision::Exact(ScalarValue::UInt32(Some(1))),
                sum_value: Precision::Exact(ScalarValue::UInt32(Some(7))),
                null_count: Precision::Exact(0),
                byte_size: Precision::Absent,
            }],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(10),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::UInt32(Some(12))),
                min_value: Precision::Exact(ScalarValue::UInt32(Some(0))),
                sum_value: Precision::Exact(ScalarValue::UInt32(Some(11))),
                null_count: Precision::Exact(0),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        // 7 * right_row_count (3) = 21; 11 * left_row_count (2) = 22.
        assert_eq!(
            result.column_statistics[0].sum_value,
            Precision::Exact(ScalarValue::UInt64(Some(21)))
        );
        assert_eq!(
            result.column_statistics[1].sum_value,
            Precision::Exact(ScalarValue::UInt64(Some(22)))
        );
    }

    /// A 3-row left and 2-row right input produce all 6 row pairs.
    #[tokio::test]
    async fn test_join() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 6]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?;

        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +----+----+----+----+----+----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+----+----+----+----+
        | 1  | 4  | 7  | 10 | 12 | 14 |
        | 1  | 4  | 7  | 11 | 13 | 15 |
        | 2  | 5  | 8  | 10 | 12 | 14 |
        | 2  | 5  | 8  | 11 | 13 | 15 |
        | 3  | 6  | 9  | 10 | 12 | 14 |
        | 3  | 6  | 9  | 11 | 13 | 15 |
        +----+----+----+----+----+----+
        ");

        assert_join_metrics!(metrics, 6);

        Ok(())
    }

    /// Buffering the left input under a 100-byte memory limit fails with a
    /// resources-exhausted error attributed to `CrossJoinExec`.
    #[tokio::test]
    async fn test_overallocation() -> Result<()> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(100, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let err = join_collect(left, right, task_ctx).await.unwrap_err();

        assert_contains!(
            err.to_string(),
            "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n CrossJoinExec"
        );

        Ok(())
    }

    /// Returns the field names of `schema` in order.
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }
}