Skip to main content

datafusion_physical_plan/joins/
cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the cross join plan for loading the left side of the cross join
19//! and producing batches in parallel for the right partitions
20
21use std::{sync::Arc, task::Poll};
22
23use super::utils::{
24    BatchSplitter, BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer,
25    OnceAsync, OnceFut, StatefulStreamResult, adjust_right_output_partitioning,
26    reorder_output_after_swap,
27};
28use crate::execution_plan::{EmissionType, boundedness_from_children};
29use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
30use crate::projection::{
31    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
32    physical_to_column_exprs,
33};
34use crate::stream::EmptyRecordBatchStream;
35use crate::{
36    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
37    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
38    SendableRecordBatchStream, Statistics, check_if_same_properties, handle_state,
39};
40
41use arrow::array::{RecordBatch, RecordBatchOptions};
42use arrow::compute::concat_batches;
43use arrow::datatypes::{Fields, Schema, SchemaRef};
44use datafusion_common::stats::Precision;
45use datafusion_common::{
46    JoinType, Result, ScalarValue, assert_eq_or_internal_err, internal_err,
47};
48use datafusion_execution::TaskContext;
49use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
50use datafusion_physical_expr::equivalence::join_equivalence_properties;
51
52use async_trait::async_trait;
53use futures::{Stream, StreamExt, TryStreamExt, ready};
54
/// Data of the left (build) side of the join, fully buffered in memory.
///
/// Produced once by `load_left_input` and read by the output streams.
#[derive(Debug)]
struct JoinLeftData {
    /// Single [`RecordBatch`] holding all rows from the left side
    /// (the concatenation of every batch read from the left input).
    merged_batch: RecordBatch,
    /// Memory reservation that was grown while the left batches were
    /// collected. It is never read: it relies on drop semantics to release
    /// the reservation when `JoinLeftData` is dropped.
    _reservation: MemoryReservation,
}
64
#[expect(rustdoc::private_intra_doc_links)]
/// Cross Join Execution Plan
///
/// This operator is used when there are no predicates between two tables and
/// returns the Cartesian product of the two tables.
///
/// Buffers the left input into memory and then streams batches from each
/// partition of the right input, combining them with the buffered left input
/// to generate the output.
///
/// # Clone / Shared State
///
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it cannot be [`Clone`].
#[derive(Debug)]
pub struct CrossJoinExec {
    /// Left (build) side, which gets loaded into memory
    pub left: Arc<dyn ExecutionPlan>,
    /// Right (probe) side, whose batches are combined with the left side
    pub right: Arc<dyn ExecutionPlan>,
    /// The schema once the join is applied: left fields followed by right fields
    schema: SchemaRef,
    /// Buffered copy of left (build) side in memory.
    ///
    /// This structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the left side loading.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution plan metrics
    metrics: ExecutionPlanMetricsSet,
    /// Properties such as schema, equivalence properties, ordering, partitioning, etc.
    cache: Arc<PlanProperties>,
}
100
101impl CrossJoinExec {
102    /// Create a new [CrossJoinExec].
103    pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
104        // left then right
105        let (all_columns, metadata) = {
106            let left_schema = left.schema();
107            let right_schema = right.schema();
108            let left_fields = left_schema.fields().iter();
109            let right_fields = right_schema.fields().iter();
110
111            let mut metadata = left_schema.metadata().clone();
112            metadata.extend(right_schema.metadata().clone());
113
114            (
115                left_fields.chain(right_fields).cloned().collect::<Fields>(),
116                metadata,
117            )
118        };
119
120        let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
121        let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap();
122
123        CrossJoinExec {
124            left,
125            right,
126            schema,
127            left_fut: Default::default(),
128            metrics: ExecutionPlanMetricsSet::default(),
129            cache: Arc::new(cache),
130        }
131    }
132
133    /// left (build) side which gets loaded in memory
134    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
135        &self.left
136    }
137
138    /// right side which gets combined with left side
139    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
140        &self.right
141    }
142
143    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
144    fn compute_properties(
145        left: &Arc<dyn ExecutionPlan>,
146        right: &Arc<dyn ExecutionPlan>,
147        schema: SchemaRef,
148    ) -> Result<PlanProperties> {
149        // Calculate equivalence properties
150        // TODO: Check equivalence properties of cross join, it may preserve
151        //       ordering in some cases.
152        let eq_properties = join_equivalence_properties(
153            left.equivalence_properties().clone(),
154            right.equivalence_properties().clone(),
155            &JoinType::Full,
156            schema,
157            &[false, false],
158            None,
159            &[],
160        )?;
161
162        // Get output partitioning:
163        // TODO: Optimize the cross join implementation to generate M * N
164        //       partitions.
165        let output_partitioning = adjust_right_output_partitioning(
166            right.output_partitioning(),
167            left.schema().fields.len(),
168        )?;
169
170        Ok(PlanProperties::new(
171            eq_properties,
172            output_partitioning,
173            EmissionType::Final,
174            boundedness_from_children([left, right]),
175        ))
176    }
177
178    /// Returns a new `ExecutionPlan` that computes the same join as this one,
179    /// with the left and right inputs swapped using the  specified
180    /// `partition_mode`.
181    ///
182    /// # Notes:
183    ///
184    /// This function should be called BEFORE inserting any repartitioning
185    /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`]
186    /// for more details.
187    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
188        let new_join =
189            CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
190        reorder_output_after_swap(
191            Arc::new(new_join),
192            &self.left.schema(),
193            &self.right.schema(),
194        )
195    }
196
197    fn with_new_children_and_same_properties(
198        &self,
199        mut children: Vec<Arc<dyn ExecutionPlan>>,
200    ) -> Self {
201        let left = children.swap_remove(0);
202        let right = children.swap_remove(0);
203
204        Self {
205            left,
206            right,
207            metrics: ExecutionPlanMetricsSet::new(),
208            left_fut: Default::default(),
209            cache: Arc::clone(&self.cache),
210            schema: Arc::clone(&self.schema),
211        }
212    }
213}
214
215/// Asynchronously collect the result of the left child
216async fn load_left_input(
217    stream: SendableRecordBatchStream,
218    metrics: BuildProbeJoinMetrics,
219    reservation: MemoryReservation,
220) -> Result<JoinLeftData> {
221    let left_schema = stream.schema();
222
223    // Load all batches and count the rows
224    let (batches, _metrics, reservation) = stream
225        .try_fold(
226            (Vec::new(), metrics, reservation),
227            |(mut batches, metrics, reservation), batch| async {
228                let batch_size = batch.get_array_memory_size();
229                // Reserve memory for incoming batch
230                reservation.try_grow(batch_size)?;
231                // Update metrics
232                metrics.build_mem_used.add(batch_size);
233                metrics.build_input_batches.add(1);
234                metrics.build_input_rows.add(batch.num_rows());
235                // Push batch to output
236                batches.push(batch);
237                Ok((batches, metrics, reservation))
238            },
239        )
240        .await?;
241
242    let merged_batch = concat_batches(&left_schema, &batches)?;
243
244    Ok(JoinLeftData {
245        merged_batch,
246        _reservation: reservation,
247    })
248}
249
250impl DisplayAs for CrossJoinExec {
251    fn fmt_as(
252        &self,
253        t: DisplayFormatType,
254        f: &mut std::fmt::Formatter,
255    ) -> std::fmt::Result {
256        match t {
257            DisplayFormatType::Default | DisplayFormatType::Verbose => {
258                write!(f, "CrossJoinExec")
259            }
260            DisplayFormatType::TreeRender => {
261                // no extra info to display
262                Ok(())
263            }
264        }
265    }
266}
267
268impl ExecutionPlan for CrossJoinExec {
269    fn name(&self) -> &'static str {
270        "CrossJoinExec"
271    }
272
273    fn properties(&self) -> &Arc<PlanProperties> {
274        &self.cache
275    }
276
277    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
278        vec![&self.left, &self.right]
279    }
280
281    fn metrics(&self) -> Option<MetricsSet> {
282        Some(self.metrics.clone_inner())
283    }
284
285    fn with_new_children(
286        self: Arc<Self>,
287        children: Vec<Arc<dyn ExecutionPlan>>,
288    ) -> Result<Arc<dyn ExecutionPlan>> {
289        check_if_same_properties!(self, children);
290        Ok(Arc::new(CrossJoinExec::new(
291            Arc::clone(&children[0]),
292            Arc::clone(&children[1]),
293        )))
294    }
295
296    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
297        let new_exec = CrossJoinExec {
298            left: Arc::clone(&self.left),
299            right: Arc::clone(&self.right),
300            schema: Arc::clone(&self.schema),
301            left_fut: Default::default(), // reset the build side!
302            metrics: ExecutionPlanMetricsSet::default(),
303            cache: Arc::clone(&self.cache),
304        };
305        Ok(Arc::new(new_exec))
306    }
307
308    fn required_input_distribution(&self) -> Vec<Distribution> {
309        vec![
310            Distribution::SinglePartition,
311            Distribution::UnspecifiedDistribution,
312        ]
313    }
314
315    fn execute(
316        &self,
317        partition: usize,
318        context: Arc<TaskContext>,
319    ) -> Result<SendableRecordBatchStream> {
320        assert_eq_or_internal_err!(
321            self.left.output_partitioning().partition_count(),
322            1,
323            "Invalid CrossJoinExec, the output partition count of the left child must be 1,\
324                 consider using CoalescePartitionsExec or the EnforceDistribution rule"
325        );
326
327        let stream = self.right.execute(partition, Arc::clone(&context))?;
328
329        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
330
331        // Initialization of operator-level reservation
332        let reservation =
333            MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
334
335        let batch_size = context.session_config().batch_size();
336        let enforce_batch_size_in_joins =
337            context.session_config().enforce_batch_size_in_joins();
338
339        let left_fut = self.left_fut.try_once(|| {
340            let left_stream = self.left.execute(0, context)?;
341
342            Ok(load_left_input(
343                left_stream,
344                join_metrics.clone(),
345                reservation,
346            ))
347        })?;
348
349        if enforce_batch_size_in_joins {
350            Ok(Box::pin(CrossJoinStream {
351                schema: Arc::clone(&self.schema),
352                left_fut,
353                right: stream,
354                left_index: 0,
355                join_metrics,
356                state: CrossJoinStreamState::WaitBuildSide,
357                left_data: RecordBatch::new_empty(self.left().schema()),
358                batch_transformer: BatchSplitter::new(batch_size),
359            }))
360        } else {
361            Ok(Box::pin(CrossJoinStream {
362                schema: Arc::clone(&self.schema),
363                left_fut,
364                right: stream,
365                left_index: 0,
366                join_metrics,
367                state: CrossJoinStreamState::WaitBuildSide,
368                left_data: RecordBatch::new_empty(self.left().schema()),
369                batch_transformer: NoopBatchTransformer::new(),
370            }))
371        }
372    }
373
374    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
375        // Get the all partitions statistics of the left
376        let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(None)?);
377        let right_stats =
378            Arc::unwrap_or_clone(self.right.partition_statistics(partition)?);
379
380        Ok(Arc::new(stats_cartesian_product(left_stats, right_stats)))
381    }
382
383    /// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done,
384    /// it returns the new swapped version having the [`CrossJoinExec`] as the top plan.
385    /// Otherwise, it returns None.
386    fn try_swapping_with_projection(
387        &self,
388        projection: &ProjectionExec,
389    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
390        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
391        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
392        else {
393            return Ok(None);
394        };
395
396        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
397            self.left().schema().fields().len(),
398            &projection_as_columns,
399        );
400
401        if !join_allows_pushdown(
402            &projection_as_columns,
403            &self.schema(),
404            far_right_left_col_ind,
405            far_left_right_col_ind,
406        ) {
407            return Ok(None);
408        }
409
410        let (new_left, new_right) = new_join_children(
411            &projection_as_columns,
412            far_right_left_col_ind,
413            far_left_right_col_ind,
414            self.left(),
415            self.right(),
416        )?;
417
418        Ok(Some(Arc::new(CrossJoinExec::new(
419            Arc::new(new_left),
420            Arc::new(new_right),
421        ))))
422    }
423}
424
425/// [left/right]_col_count are required in case the column statistics are None
426fn stats_cartesian_product(
427    left_stats: Statistics,
428    right_stats: Statistics,
429) -> Statistics {
430    let left_row_count = left_stats.num_rows;
431    let right_row_count = right_stats.num_rows;
432
433    // calculate global stats
434    let num_rows = left_row_count.multiply(&right_row_count);
435    // the result size is two times a*b because you have the columns of both left and right
436    let total_byte_size = left_stats
437        .total_byte_size
438        .multiply(&right_stats.total_byte_size)
439        .multiply(&Precision::Exact(2));
440
441    let left_col_stats = left_stats.column_statistics;
442    let right_col_stats = right_stats.column_statistics;
443
444    // the null counts must be multiplied by the row counts of the other side (if defined)
445    // Min, max and distinct_count on the other hand are invariants.
446    let cross_join_stats = left_col_stats
447        .into_iter()
448        .map(|s| {
449            let widened_sum = s.sum_value.cast_to_sum_type();
450            ColumnStatistics {
451                null_count: s.null_count.multiply(&right_row_count),
452                distinct_count: s.distinct_count,
453                min_value: s.min_value,
454                max_value: s.max_value,
455                sum_value: widened_sum
456                    .get_value()
457                    // Cast the row count into the same type as any existing sum value
458                    .and_then(|v| {
459                        Precision::<ScalarValue>::from(right_row_count)
460                            .cast_to(&v.data_type())
461                            .ok()
462                    })
463                    .map(|row_count| widened_sum.multiply(&row_count))
464                    .unwrap_or(Precision::Absent),
465                byte_size: Precision::Absent,
466            }
467        })
468        .chain(right_col_stats.into_iter().map(|s| {
469            let widened_sum = s.sum_value.cast_to_sum_type();
470            ColumnStatistics {
471                null_count: s.null_count.multiply(&left_row_count),
472                distinct_count: s.distinct_count,
473                min_value: s.min_value,
474                max_value: s.max_value,
475                sum_value: widened_sum
476                    .get_value()
477                    // Cast the row count into the same type as any existing sum value
478                    .and_then(|v| {
479                        Precision::<ScalarValue>::from(left_row_count)
480                            .cast_to(&v.data_type())
481                            .ok()
482                    })
483                    .map(|row_count| widened_sum.multiply(&row_count))
484                    .unwrap_or(Precision::Absent),
485                byte_size: Precision::Absent,
486            }
487        }))
488        .collect();
489
490    Statistics {
491        num_rows,
492        total_byte_size,
493        column_statistics: cross_join_stats,
494    }
495}
496
/// A stream that issues [RecordBatch]es as they arrive from the right side of the join.
///
/// `T` is the batch transformer applied to every joined batch before it is
/// emitted.
struct CrossJoinStream<T> {
    /// Schema of the batches produced by this stream (the join output schema)
    schema: Arc<Schema>,
    /// Future for data from left side
    left_fut: OnceFut<JoinLeftData>,
    /// Right side stream
    right: SendableRecordBatchStream,
    /// Index of the left row currently being combined with the right batch
    left_index: usize,
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// State of the stream
    state: CrossJoinStreamState,
    /// Left data (copy of the entire buffered left side); an empty batch
    /// until the build side has been collected
    left_data: RecordBatch,
    /// Batch transformer
    batch_transformer: T,
}
516
517impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
518    fn schema(&self) -> SchemaRef {
519        Arc::clone(&self.schema)
520    }
521}
522
/// Represents states of CrossJoinStream
enum CrossJoinStreamState {
    /// Waiting for the left (build) side to be loaded; the initial state
    WaitBuildSide,
    /// Ready to pull the next batch from the right (probe) side
    FetchProbeBatch,
    /// Holds the currently processed right side batch
    BuildBatches(RecordBatch),
}
530
531impl CrossJoinStreamState {
532    /// Tries to extract RecordBatch from CrossJoinStreamState enum.
533    /// Returns an error if state is not BuildBatches state.
534    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
535        match self {
536            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
537            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
538        }
539    }
540}
541
542fn build_batch(
543    left_index: usize,
544    batch: &RecordBatch,
545    left_data: &RecordBatch,
546    schema: &Schema,
547) -> Result<RecordBatch> {
548    // Repeat value on the left n times
549    let arrays = left_data
550        .columns()
551        .iter()
552        .map(|arr| {
553            let scalar = ScalarValue::try_from_array(arr, left_index)?;
554            scalar.to_array_of_size(batch.num_rows())
555        })
556        .collect::<Result<Vec<_>>>()?;
557
558    RecordBatch::try_new_with_options(
559        Arc::new(schema.clone()),
560        arrays
561            .iter()
562            .chain(batch.columns().iter())
563            .cloned()
564            .collect(),
565        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
566    )
567    .map_err(Into::into)
568}
569
570#[async_trait]
571impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
572    type Item = Result<RecordBatch>;
573
574    fn poll_next(
575        mut self: std::pin::Pin<&mut Self>,
576        cx: &mut std::task::Context<'_>,
577    ) -> Poll<Option<Self::Item>> {
578        self.poll_next_impl(cx)
579    }
580}
581
impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Separate implementation function that unpins the [`CrossJoinStream`] so
    /// that partial borrows work correctly
    ///
    /// Dispatches on the current state: wait for the build side, fetch a
    /// probe batch, or build output batches.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // NOTE(review): `handle_state!` presumably restarts this loop when a
        // step reports `Continue` and yields the value otherwise; confirm
        // against the macro definition.
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    let poll = handle_state!(self.build_batches());
                    self.join_metrics.baseline.record_poll(poll)
                }
            };
        }
    }

    /// Collects build (left) side of the join into the state. In case of an empty build batch,
    /// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        let build_timer = self.join_metrics.build_time.timer();
        // Returns `Pending` from this function until the shared left-side
        // future has completed.
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            // An empty left side yields an empty Cartesian product: end the
            // stream without reading the right side.
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Fetches the probe (right) batch, updates the metrics, and save the batch in the state.
    /// Then, the state is updated to build result batches.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        // Every new probe batch is joined starting from the first left row.
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            None => {
                // Release the right (probe) input pipeline's resources.
                let right_schema = self.right.schema();
                self.right = Box::pin(EmptyRecordBatchStream::new(right_schema));
                return Poll::Ready(Ok(StatefulStreamResult::Ready(None)));
            }
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Joins the indexed row of left data with the current probe batch.
    /// If all the results are produced, the state is set to fetch new probe batch.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            match self.batch_transformer.next() {
                // The transformer has nothing to emit: join the current left
                // row with the probe batch and hand the result to it.
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                Some((batch, last)) => {
                    // Move on to the next left row once the transformer
                    // reports the last piece of the current joined batch.
                    if last {
                        self.left_index += 1;
                    }

                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            // All left rows have been combined with this probe batch.
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}
685
686#[cfg(test)]
687mod tests {
688    use super::*;
689    use crate::common;
690    use crate::test::{assert_join_metrics, build_table_scan_i32};
691
692    use datafusion_common::{assert_contains, test_util::batches_to_sort_string};
693    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
694    use insta::assert_snapshot;
695
696    async fn join_collect(
697        left: Arc<dyn ExecutionPlan>,
698        right: Arc<dyn ExecutionPlan>,
699        context: Arc<TaskContext>,
700    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
701        let join = CrossJoinExec::new(left, right);
702        let columns_header = columns(&join.schema());
703
704        let stream = join.execute(0, context)?;
705        let batches = common::collect(stream).await?;
706        let metrics = join.metrics().unwrap();
707
708        Ok((columns_header, batches, metrics))
709    }
710
    /// Checks `stats_cartesian_product` when every input statistic is exact:
    /// row counts multiply, the byte size is twice the product of both byte
    /// sizes, and each column's null count and sum are scaled by the row
    /// count of the opposite side while min, max and distinct count are kept.
    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        // Left input: one Int64 column with a sum, one string column without
        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
            ],
        };

        // Right input: a single Int64 column
        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        // Left columns come first and are scaled by the right row count;
        // the right column follows, scaled by the left row count.
        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }
793
794    #[tokio::test]
795    async fn test_stats_cartesian_product_with_unknown_size() {
796        let left_row_count = 11;
797
798        let left = Statistics {
799            num_rows: Precision::Exact(left_row_count),
800            total_byte_size: Precision::Exact(23),
801            column_statistics: vec![
802                ColumnStatistics {
803                    distinct_count: Precision::Exact(5),
804                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
805                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
806                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
807                    null_count: Precision::Exact(0),
808                    byte_size: Precision::Absent,
809                },
810                ColumnStatistics {
811                    distinct_count: Precision::Exact(1),
812                    max_value: Precision::Exact(ScalarValue::from("x")),
813                    min_value: Precision::Exact(ScalarValue::from("a")),
814                    sum_value: Precision::Absent,
815                    null_count: Precision::Exact(3),
816                    byte_size: Precision::Absent,
817                },
818            ],
819        };
820
821        let right = Statistics {
822            num_rows: Precision::Absent,
823            total_byte_size: Precision::Absent,
824            column_statistics: vec![ColumnStatistics {
825                distinct_count: Precision::Exact(3),
826                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
827                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
828                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
829                null_count: Precision::Exact(2),
830                byte_size: Precision::Absent,
831            }],
832        };
833
834        let result = stats_cartesian_product(left, right);
835
836        let expected = Statistics {
837            num_rows: Precision::Absent,
838            total_byte_size: Precision::Absent,
839            column_statistics: vec![
840                ColumnStatistics {
841                    distinct_count: Precision::Exact(5),
842                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
843                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
844                    sum_value: Precision::Absent, // we don't know the row count on the right
845                    null_count: Precision::Absent, // we don't know the row count on the right
846                    byte_size: Precision::Absent,
847                },
848                ColumnStatistics {
849                    distinct_count: Precision::Exact(1),
850                    max_value: Precision::Exact(ScalarValue::from("x")),
851                    min_value: Precision::Exact(ScalarValue::from("a")),
852                    sum_value: Precision::Absent,
853                    null_count: Precision::Absent, // we don't know the row count on the right
854                    byte_size: Precision::Absent,
855                },
856                ColumnStatistics {
857                    distinct_count: Precision::Exact(3),
858                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
859                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
860                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
861                        20 * left_row_count as i64,
862                    ))),
863                    null_count: Precision::Exact(2 * left_row_count),
864                    byte_size: Precision::Absent,
865                },
866            ],
867        };
868
869        assert_eq!(result, expected);
870    }
871
872    #[tokio::test]
873    async fn test_stats_cartesian_product_unsigned_sum_widens_to_u64() {
874        let left_row_count = 2;
875        let right_row_count = 3;
876
877        let left = Statistics {
878            num_rows: Precision::Exact(left_row_count),
879            total_byte_size: Precision::Exact(10),
880            column_statistics: vec![ColumnStatistics {
881                distinct_count: Precision::Exact(2),
882                max_value: Precision::Exact(ScalarValue::UInt32(Some(10))),
883                min_value: Precision::Exact(ScalarValue::UInt32(Some(1))),
884                sum_value: Precision::Exact(ScalarValue::UInt32(Some(7))),
885                null_count: Precision::Exact(0),
886                byte_size: Precision::Absent,
887            }],
888        };
889
890        let right = Statistics {
891            num_rows: Precision::Exact(right_row_count),
892            total_byte_size: Precision::Exact(10),
893            column_statistics: vec![ColumnStatistics {
894                distinct_count: Precision::Exact(3),
895                max_value: Precision::Exact(ScalarValue::UInt32(Some(12))),
896                min_value: Precision::Exact(ScalarValue::UInt32(Some(0))),
897                sum_value: Precision::Exact(ScalarValue::UInt32(Some(11))),
898                null_count: Precision::Exact(0),
899                byte_size: Precision::Absent,
900            }],
901        };
902
903        let result = stats_cartesian_product(left, right);
904
905        assert_eq!(
906            result.column_statistics[0].sum_value,
907            Precision::Exact(ScalarValue::UInt64(Some(21)))
908        );
909        assert_eq!(
910            result.column_statistics[1].sum_value,
911            Precision::Exact(ScalarValue::UInt64(Some(22)))
912        );
913    }
914
915    #[tokio::test]
916    async fn test_join() -> Result<()> {
917        let task_ctx = Arc::new(TaskContext::default());
918
919        let left = build_table_scan_i32(
920            ("a1", &vec![1, 2, 3]),
921            ("b1", &vec![4, 5, 6]),
922            ("c1", &vec![7, 8, 9]),
923        );
924        let right = build_table_scan_i32(
925            ("a2", &vec![10, 11]),
926            ("b2", &vec![12, 13]),
927            ("c2", &vec![14, 15]),
928        );
929
930        let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?;
931
932        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
933
934        assert_snapshot!(batches_to_sort_string(&batches), @r"
935        +----+----+----+----+----+----+
936        | a1 | b1 | c1 | a2 | b2 | c2 |
937        +----+----+----+----+----+----+
938        | 1  | 4  | 7  | 10 | 12 | 14 |
939        | 1  | 4  | 7  | 11 | 13 | 15 |
940        | 2  | 5  | 8  | 10 | 12 | 14 |
941        | 2  | 5  | 8  | 11 | 13 | 15 |
942        | 3  | 6  | 9  | 10 | 12 | 14 |
943        | 3  | 6  | 9  | 11 | 13 | 15 |
944        +----+----+----+----+----+----+
945        ");
946
947        assert_join_metrics!(metrics, 6);
948
949        Ok(())
950    }
951
952    #[tokio::test]
953    async fn test_overallocation() -> Result<()> {
954        let runtime = RuntimeEnvBuilder::new()
955            .with_memory_limit(100, 1.0)
956            .build_arc()?;
957        let task_ctx = TaskContext::default().with_runtime(runtime);
958        let task_ctx = Arc::new(task_ctx);
959
960        let left = build_table_scan_i32(
961            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
962            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
963            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
964        );
965        let right = build_table_scan_i32(
966            ("a2", &vec![10, 11]),
967            ("b2", &vec![12, 13]),
968            ("c2", &vec![14, 15]),
969        );
970
971        let err = join_collect(left, right, task_ctx).await.unwrap_err();
972
973        assert_contains!(
974            err.to_string(),
975            "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n  CrossJoinExec"
976        );
977
978        Ok(())
979    }
980
981    /// Returns the column names on the schema
982    fn columns(schema: &Schema) -> Vec<String> {
983        schema.fields().iter().map(|f| f.name().clone()).collect()
984    }
985}