// datafusion_physical_plan/joins/cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the cross join plan for loading the left side of the cross join
19//! and producing batches in parallel for the right partitions
20
21use std::{any::Any, sync::Arc, task::Poll};
22
23use super::utils::{
24    BatchSplitter, BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer,
25    OnceAsync, OnceFut, StatefulStreamResult, adjust_right_output_partitioning,
26    reorder_output_after_swap,
27};
28use crate::execution_plan::{EmissionType, boundedness_from_children};
29use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
30use crate::projection::{
31    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
32    physical_to_column_exprs,
33};
34use crate::{
35    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
36    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
37    SendableRecordBatchStream, Statistics, handle_state,
38};
39
40use arrow::array::{RecordBatch, RecordBatchOptions};
41use arrow::compute::concat_batches;
42use arrow::datatypes::{Fields, Schema, SchemaRef};
43use datafusion_common::stats::Precision;
44use datafusion_common::{
45    JoinType, Result, ScalarValue, assert_eq_or_internal_err, internal_err,
46};
47use datafusion_execution::TaskContext;
48use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
49use datafusion_physical_expr::equivalence::join_equivalence_properties;
50
51use async_trait::async_trait;
52use futures::{Stream, StreamExt, TryStreamExt, ready};
53
/// Data of the left (build) side that is buffered into memory.
#[derive(Debug)]
struct JoinLeftData {
    /// Single RecordBatch with all rows from the left side, produced by
    /// concatenating every input batch of the left child.
    merged_batch: RecordBatch,
    /// Track memory reservation for merged_batch. Relies on drop
    /// semantics to release reservation when JoinLeftData is dropped.
    _reservation: MemoryReservation,
}
63
#[expect(rustdoc::private_intra_doc_links)]
/// Cross Join Execution Plan
///
/// This operator is used when there are no predicates between two tables and
/// returns the Cartesian product of the two tables.
///
/// Buffers the left input into memory and then streams batches from each
/// partition on the right input combining them with the buffered left input
/// to generate the output.
///
/// # Clone / Shared State
///
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it can not be [`Clone`]
#[derive(Debug)]
pub struct CrossJoinExec {
    /// left (build) side which gets loaded in memory
    pub left: Arc<dyn ExecutionPlan>,
    /// right (probe) side whose batches are combined with the left side
    pub right: Arc<dyn ExecutionPlan>,
    /// The schema once the join is applied (left fields followed by right fields)
    schema: SchemaRef,
    /// Buffered copy of left (build) side in memory.
    ///
    /// This structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the left side loading.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution plan metrics
    metrics: ExecutionPlanMetricsSet,
    /// Properties such as schema, equivalence properties, ordering, partitioning, etc.
    cache: PlanProperties,
}
99
100impl CrossJoinExec {
101    /// Create a new [CrossJoinExec].
102    pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
103        // left then right
104        let (all_columns, metadata) = {
105            let left_schema = left.schema();
106            let right_schema = right.schema();
107            let left_fields = left_schema.fields().iter();
108            let right_fields = right_schema.fields().iter();
109
110            let mut metadata = left_schema.metadata().clone();
111            metadata.extend(right_schema.metadata().clone());
112
113            (
114                left_fields.chain(right_fields).cloned().collect::<Fields>(),
115                metadata,
116            )
117        };
118
119        let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
120        let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap();
121
122        CrossJoinExec {
123            left,
124            right,
125            schema,
126            left_fut: Default::default(),
127            metrics: ExecutionPlanMetricsSet::default(),
128            cache,
129        }
130    }
131
132    /// left (build) side which gets loaded in memory
133    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
134        &self.left
135    }
136
137    /// right side which gets combined with left side
138    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
139        &self.right
140    }
141
142    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
143    fn compute_properties(
144        left: &Arc<dyn ExecutionPlan>,
145        right: &Arc<dyn ExecutionPlan>,
146        schema: SchemaRef,
147    ) -> Result<PlanProperties> {
148        // Calculate equivalence properties
149        // TODO: Check equivalence properties of cross join, it may preserve
150        //       ordering in some cases.
151        let eq_properties = join_equivalence_properties(
152            left.equivalence_properties().clone(),
153            right.equivalence_properties().clone(),
154            &JoinType::Full,
155            schema,
156            &[false, false],
157            None,
158            &[],
159        )?;
160
161        // Get output partitioning:
162        // TODO: Optimize the cross join implementation to generate M * N
163        //       partitions.
164        let output_partitioning = adjust_right_output_partitioning(
165            right.output_partitioning(),
166            left.schema().fields.len(),
167        )?;
168
169        Ok(PlanProperties::new(
170            eq_properties,
171            output_partitioning,
172            EmissionType::Final,
173            boundedness_from_children([left, right]),
174        ))
175    }
176
177    /// Returns a new `ExecutionPlan` that computes the same join as this one,
178    /// with the left and right inputs swapped using the  specified
179    /// `partition_mode`.
180    ///
181    /// # Notes:
182    ///
183    /// This function should be called BEFORE inserting any repartitioning
184    /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`]
185    /// for more details.
186    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
187        let new_join =
188            CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
189        reorder_output_after_swap(
190            Arc::new(new_join),
191            &self.left.schema(),
192            &self.right.schema(),
193        )
194    }
195}
196
197/// Asynchronously collect the result of the left child
198async fn load_left_input(
199    stream: SendableRecordBatchStream,
200    metrics: BuildProbeJoinMetrics,
201    reservation: MemoryReservation,
202) -> Result<JoinLeftData> {
203    let left_schema = stream.schema();
204
205    // Load all batches and count the rows
206    let (batches, _metrics, reservation) = stream
207        .try_fold(
208            (Vec::new(), metrics, reservation),
209            |(mut batches, metrics, mut reservation), batch| async {
210                let batch_size = batch.get_array_memory_size();
211                // Reserve memory for incoming batch
212                reservation.try_grow(batch_size)?;
213                // Update metrics
214                metrics.build_mem_used.add(batch_size);
215                metrics.build_input_batches.add(1);
216                metrics.build_input_rows.add(batch.num_rows());
217                // Push batch to output
218                batches.push(batch);
219                Ok((batches, metrics, reservation))
220            },
221        )
222        .await?;
223
224    let merged_batch = concat_batches(&left_schema, &batches)?;
225
226    Ok(JoinLeftData {
227        merged_batch,
228        _reservation: reservation,
229    })
230}
231
232impl DisplayAs for CrossJoinExec {
233    fn fmt_as(
234        &self,
235        t: DisplayFormatType,
236        f: &mut std::fmt::Formatter,
237    ) -> std::fmt::Result {
238        match t {
239            DisplayFormatType::Default | DisplayFormatType::Verbose => {
240                write!(f, "CrossJoinExec")
241            }
242            DisplayFormatType::TreeRender => {
243                // no extra info to display
244                Ok(())
245            }
246        }
247    }
248}
249
impl ExecutionPlan for CrossJoinExec {
    fn name(&self) -> &'static str {
        "CrossJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        // Returns the properties precomputed in `compute_properties`.
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Building a fresh exec also resets `left_fut` and metrics.
        Ok(Arc::new(CrossJoinExec::new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
        )))
    }

    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
        // Keeps children/schema/properties but discards the shared buffered
        // left side and metrics so the plan can be re-executed from scratch.
        let new_exec = CrossJoinExec {
            left: Arc::clone(&self.left),
            right: Arc::clone(&self.right),
            schema: Arc::clone(&self.schema),
            left_fut: Default::default(), // reset the build side!
            metrics: ExecutionPlanMetricsSet::default(),
            cache: self.cache.clone(),
        };
        Ok(Arc::new(new_exec))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        // The left (build) side must be a single partition because it is
        // buffered once and shared by all output partitions; the right side
        // may be partitioned arbitrarily.
        vec![
            Distribution::SinglePartition,
            Distribution::UnspecifiedDistribution,
        ]
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        assert_eq_or_internal_err!(
            self.left.output_partitioning().partition_count(),
            1,
            "Invalid CrossJoinExec, the output partition count of the left child must be 1,\
                 consider using CoalescePartitionsExec or the EnforceDistribution rule"
        );

        let stream = self.right.execute(partition, Arc::clone(&context))?;

        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);

        // Initialization of operator-level reservation
        let reservation =
            MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());

        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();

        // `try_once` ensures the left side is loaded only once across all
        // output partitions; later callers share the same future.
        let left_fut = self.left_fut.try_once(|| {
            let left_stream = self.left.execute(0, context)?;

            Ok(load_left_input(
                left_stream,
                join_metrics.clone(),
                reservation,
            ))
        })?;

        // Choose the batch transformer statically: a splitter that enforces
        // the configured batch size, or a no-op pass-through.
        if enforce_batch_size_in_joins {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        // Get the all partitions statistics of the left (it is buffered whole),
        // combined with the requested partition's statistics of the right.
        let left_stats = self.left.partition_statistics(None)?;
        let right_stats = self.right.partition_statistics(partition)?;

        Ok(stats_cartesian_product(left_stats, right_stats))
    }

    /// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done,
    /// it returns the new swapped version having the [`CrossJoinExec`] as the top plan.
    /// Otherwise, it returns None.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };

        // Find where the projection's columns straddle the left/right boundary.
        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );

        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }

        // Split the projection into per-child projections below the join.
        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;

        Ok(Some(Arc::new(CrossJoinExec::new(
            Arc::new(new_left),
            Arc::new(new_right),
        ))))
    }
}
412
413/// [left/right]_col_count are required in case the column statistics are None
414fn stats_cartesian_product(
415    left_stats: Statistics,
416    right_stats: Statistics,
417) -> Statistics {
418    let left_row_count = left_stats.num_rows;
419    let right_row_count = right_stats.num_rows;
420
421    // calculate global stats
422    let num_rows = left_row_count.multiply(&right_row_count);
423    // the result size is two times a*b because you have the columns of both left and right
424    let total_byte_size = left_stats
425        .total_byte_size
426        .multiply(&right_stats.total_byte_size)
427        .multiply(&Precision::Exact(2));
428
429    let left_col_stats = left_stats.column_statistics;
430    let right_col_stats = right_stats.column_statistics;
431
432    // the null counts must be multiplied by the row counts of the other side (if defined)
433    // Min, max and distinct_count on the other hand are invariants.
434    let cross_join_stats = left_col_stats
435        .into_iter()
436        .map(|s| ColumnStatistics {
437            null_count: s.null_count.multiply(&right_row_count),
438            distinct_count: s.distinct_count,
439            min_value: s.min_value,
440            max_value: s.max_value,
441            sum_value: s
442                .sum_value
443                .get_value()
444                // Cast the row count into the same type as any existing sum value
445                .and_then(|v| {
446                    Precision::<ScalarValue>::from(right_row_count)
447                        .cast_to(&v.data_type())
448                        .ok()
449                })
450                .map(|row_count| s.sum_value.multiply(&row_count))
451                .unwrap_or(Precision::Absent),
452            byte_size: Precision::Absent,
453        })
454        .chain(right_col_stats.into_iter().map(|s| {
455            ColumnStatistics {
456                null_count: s.null_count.multiply(&left_row_count),
457                distinct_count: s.distinct_count,
458                min_value: s.min_value,
459                max_value: s.max_value,
460                sum_value: s
461                    .sum_value
462                    .get_value()
463                    // Cast the row count into the same type as any existing sum value
464                    .and_then(|v| {
465                        Precision::<ScalarValue>::from(left_row_count)
466                            .cast_to(&v.data_type())
467                            .ok()
468                    })
469                    .map(|row_count| s.sum_value.multiply(&row_count))
470                    .unwrap_or(Precision::Absent),
471                byte_size: Precision::Absent,
472            }
473        }))
474        .collect();
475
476    Statistics {
477        num_rows,
478        total_byte_size,
479        column_statistics: cross_join_stats,
480    }
481}
482
/// A stream that issues [RecordBatch]es as they arrive from the right side of the join.
struct CrossJoinStream<T> {
    /// Schema of the output batches (left fields followed by right fields)
    schema: Arc<Schema>,
    /// Future for data from left side
    left_fut: OnceFut<JoinLeftData>,
    /// Right side stream
    right: SendableRecordBatchStream,
    /// Index of the left row currently being joined with the probe batch
    left_index: usize,
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// State of the stream
    state: CrossJoinStreamState,
    /// Left data (copy of the entire buffered left side)
    left_data: RecordBatch,
    /// Batch transformer (either splits output to the configured batch size
    /// or passes batches through unchanged)
    batch_transformer: T,
}
502
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
    /// Returns the schema of the batches produced by this stream
    /// (the combined join output schema).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
508
/// Represents states of CrossJoinStream
enum CrossJoinStreamState {
    /// Waiting for the buffered left (build) side to finish loading
    WaitBuildSide,
    /// Ready to pull the next batch from the right (probe) stream
    FetchProbeBatch,
    /// Holds the currently processed right side batch
    BuildBatches(RecordBatch),
}
516
517impl CrossJoinStreamState {
518    /// Tries to extract RecordBatch from CrossJoinStreamState enum.
519    /// Returns an error if state is not BuildBatches state.
520    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
521        match self {
522            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
523            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
524        }
525    }
526}
527
528fn build_batch(
529    left_index: usize,
530    batch: &RecordBatch,
531    left_data: &RecordBatch,
532    schema: &Schema,
533) -> Result<RecordBatch> {
534    // Repeat value on the left n times
535    let arrays = left_data
536        .columns()
537        .iter()
538        .map(|arr| {
539            let scalar = ScalarValue::try_from_array(arr, left_index)?;
540            scalar.to_array_of_size(batch.num_rows())
541        })
542        .collect::<Result<Vec<_>>>()?;
543
544    RecordBatch::try_new_with_options(
545        Arc::new(schema.clone()),
546        arrays
547            .iter()
548            .chain(batch.columns().iter())
549            .cloned()
550            .collect(),
551        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
552    )
553    .map_err(Into::into)
554}
555
#[async_trait]
impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
    type Item = Result<RecordBatch>;

    /// Delegates to [`CrossJoinStream::poll_next_impl`]; the stream is
    /// `Unpin` (via the `T: Unpin` bound), so `&mut` access through the
    /// pin is fine here.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}
567
impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Separate implementation function that unpins the [`CrossJoinStream`] so
    /// that partial borrows work correctly.
    ///
    /// Drives the state machine: WaitBuildSide -> FetchProbeBatch ->
    /// BuildBatches (and back to FetchProbeBatch when a probe batch is
    /// exhausted). The `handle_state!` macro turns a `Continue` result into
    /// another loop iteration and returns otherwise.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    // Record output row/batch metrics for emitted batches
                    let poll = handle_state!(self.build_batches());
                    self.join_metrics.baseline.record_poll(poll)
                }
            };
        }
    }

    /// Collects build (left) side of the join into the state. In case of an empty build batch,
    /// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        let build_timer = self.join_metrics.build_time.timer();
        // Await the shared left-side future (may already be resolved by
        // another partition's stream).
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            // Empty build side: the cross join produces no rows, terminate.
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Fetches the probe (right) batch, updates the metrics, and save the batch in the state.
    /// Then, the state is updated to build result batches.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        // Restart the left-row cursor for the new probe batch.
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            // Right stream exhausted: the join is complete.
            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Joins the indexed row of left data with the current probe batch.
    /// If all the results are produced, the state is set to fetch new probe batch.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            // The transformer yields `None` when it needs a new joined batch,
            // and `Some((batch, last))` while it still has output to emit.
            match self.batch_transformer.next() {
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                Some((batch, last)) => {
                    // Advance to the next left row only once the transformer
                    // has emitted the final piece of the current joined batch.
                    if last {
                        self.left_index += 1;
                    }

                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            // Every left row has been joined with this probe batch.
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}
666
667#[cfg(test)]
668mod tests {
669    use super::*;
670    use crate::common;
671    use crate::test::{assert_join_metrics, build_table_scan_i32};
672
673    use datafusion_common::{assert_contains, test_util::batches_to_sort_string};
674    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
675    use insta::assert_snapshot;
676
    /// Runs a cross join of `left` and `right` on partition 0 to completion,
    /// returning the output column names, the produced batches, and the
    /// join's metrics.
    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
        let join = CrossJoinExec::new(left, right);
        let columns_header = columns(&join.schema());

        let stream = join.execute(0, context)?;
        let batches = common::collect(stream).await?;
        let metrics = join.metrics().unwrap();

        Ok((columns_header, batches, metrics))
    }
691
    /// Checks that `stats_cartesian_product` multiplies row counts and byte
    /// sizes, scales null counts and sums by the other side's row count, and
    /// leaves min/max/distinct values untouched when all inputs are exact.
    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        // Expected: rows = 11 * 7, bytes = 2 * 23 * 27; each side's null
        // counts and sums scale by the other side's row count.
        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }
774
    /// Checks that `stats_cartesian_product` degrades to `Precision::Absent`
    /// for any statistic that depends on an unknown (Absent) right-side row
    /// count, while still scaling right-side columns by the known left count.
    #[tokio::test]
    async fn test_stats_cartesian_product_with_unknown_size() {
        let left_row_count = 11;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
            ],
        };

        // Right side has unknown row count and byte size.
        let right = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Absent, // we don't know the row count on the right
                    null_count: Precision::Absent, // we don't know the row count on the right
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent, // we don't know the row count on the right
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }
852
853    #[tokio::test]
854    async fn test_join() -> Result<()> {
855        let task_ctx = Arc::new(TaskContext::default());
856
857        let left = build_table_scan_i32(
858            ("a1", &vec![1, 2, 3]),
859            ("b1", &vec![4, 5, 6]),
860            ("c1", &vec![7, 8, 9]),
861        );
862        let right = build_table_scan_i32(
863            ("a2", &vec![10, 11]),
864            ("b2", &vec![12, 13]),
865            ("c2", &vec![14, 15]),
866        );
867
868        let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?;
869
870        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
871
872        assert_snapshot!(batches_to_sort_string(&batches), @r"
873        +----+----+----+----+----+----+
874        | a1 | b1 | c1 | a2 | b2 | c2 |
875        +----+----+----+----+----+----+
876        | 1  | 4  | 7  | 10 | 12 | 14 |
877        | 1  | 4  | 7  | 11 | 13 | 15 |
878        | 2  | 5  | 8  | 10 | 12 | 14 |
879        | 2  | 5  | 8  | 11 | 13 | 15 |
880        | 3  | 6  | 9  | 10 | 12 | 14 |
881        | 3  | 6  | 9  | 11 | 13 | 15 |
882        +----+----+----+----+----+----+
883        ");
884
885        assert_join_metrics!(metrics, 6);
886
887        Ok(())
888    }
889
890    #[tokio::test]
891    async fn test_overallocation() -> Result<()> {
892        let runtime = RuntimeEnvBuilder::new()
893            .with_memory_limit(100, 1.0)
894            .build_arc()?;
895        let task_ctx = TaskContext::default().with_runtime(runtime);
896        let task_ctx = Arc::new(task_ctx);
897
898        let left = build_table_scan_i32(
899            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
900            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
901            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
902        );
903        let right = build_table_scan_i32(
904            ("a2", &vec![10, 11]),
905            ("b2", &vec![12, 13]),
906            ("c2", &vec![14, 15]),
907        );
908
909        let err = join_collect(left, right, task_ctx).await.unwrap_err();
910
911        assert_contains!(
912            err.to_string(),
913            "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n  CrossJoinExec"
914        );
915
916        Ok(())
917    }
918
919    /// Returns the column names on the schema
920    fn columns(schema: &Schema) -> Vec<String> {
921        schema.fields().iter().map(|f| f.name().clone()).collect()
922    }
923}