// datafusion_physical_plan/joins/cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the cross join plan for loading the left side of the cross join
19//! and producing batches in parallel for the right partitions
20
21use std::{any::Any, sync::Arc, task::Poll};
22
23use super::utils::{
24    BatchSplitter, BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer,
25    OnceAsync, OnceFut, StatefulStreamResult, adjust_right_output_partitioning,
26    reorder_output_after_swap,
27};
28use crate::execution_plan::{EmissionType, boundedness_from_children};
29use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
30use crate::projection::{
31    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
32    physical_to_column_exprs,
33};
34use crate::{
35    ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
36    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
37    SendableRecordBatchStream, Statistics, check_if_same_properties, handle_state,
38};
39
40use arrow::array::{RecordBatch, RecordBatchOptions};
41use arrow::compute::concat_batches;
42use arrow::datatypes::{Fields, Schema, SchemaRef};
43use datafusion_common::stats::Precision;
44use datafusion_common::{
45    JoinType, Result, ScalarValue, assert_eq_or_internal_err, internal_err,
46};
47use datafusion_execution::TaskContext;
48use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
49use datafusion_physical_expr::equivalence::join_equivalence_properties;
50
51use async_trait::async_trait;
52use futures::{Stream, StreamExt, TryStreamExt, ready};
53
/// Data of the left (build) side that is buffered into memory once and then
/// shared by every output partition of the cross join.
#[derive(Debug)]
struct JoinLeftData {
    /// Single RecordBatch with all rows from the left side, produced by
    /// concatenating every input batch (see `load_left_input`)
    merged_batch: RecordBatch,
    /// Tracks the memory reservation for `merged_batch`. Relies on drop
    /// semantics to release the reservation when JoinLeftData is dropped.
    _reservation: MemoryReservation,
}
63
#[expect(rustdoc::private_intra_doc_links)]
/// Cross Join Execution Plan
///
/// This operator is used when there are no predicates between two tables and
/// returns the Cartesian product of the two tables.
///
/// Buffers the left input into memory and then streams batches from each
/// partition on the right input combining them with the buffered left input
/// to generate the output.
///
/// # Clone / Shared State
///
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it can not be [`Clone`]
#[derive(Debug)]
pub struct CrossJoinExec {
    /// left (build) side which gets loaded in memory
    pub left: Arc<dyn ExecutionPlan>,
    /// right (probe) side which are combined with left side
    pub right: Arc<dyn ExecutionPlan>,
    /// The schema once the join is applied: all left fields followed by all
    /// right fields (see `CrossJoinExec::new`)
    schema: SchemaRef,
    /// Buffered copy of left (build) side in memory.
    ///
    /// This structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the left side loading.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution plan metrics
    metrics: ExecutionPlanMetricsSet,
    /// Properties such as schema, equivalence properties, ordering, partitioning, etc.
    cache: Arc<PlanProperties>,
}
99
100impl CrossJoinExec {
101    /// Create a new [CrossJoinExec].
102    pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
103        // left then right
104        let (all_columns, metadata) = {
105            let left_schema = left.schema();
106            let right_schema = right.schema();
107            let left_fields = left_schema.fields().iter();
108            let right_fields = right_schema.fields().iter();
109
110            let mut metadata = left_schema.metadata().clone();
111            metadata.extend(right_schema.metadata().clone());
112
113            (
114                left_fields.chain(right_fields).cloned().collect::<Fields>(),
115                metadata,
116            )
117        };
118
119        let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
120        let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap();
121
122        CrossJoinExec {
123            left,
124            right,
125            schema,
126            left_fut: Default::default(),
127            metrics: ExecutionPlanMetricsSet::default(),
128            cache: Arc::new(cache),
129        }
130    }
131
132    /// left (build) side which gets loaded in memory
133    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
134        &self.left
135    }
136
137    /// right side which gets combined with left side
138    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
139        &self.right
140    }
141
142    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
143    fn compute_properties(
144        left: &Arc<dyn ExecutionPlan>,
145        right: &Arc<dyn ExecutionPlan>,
146        schema: SchemaRef,
147    ) -> Result<PlanProperties> {
148        // Calculate equivalence properties
149        // TODO: Check equivalence properties of cross join, it may preserve
150        //       ordering in some cases.
151        let eq_properties = join_equivalence_properties(
152            left.equivalence_properties().clone(),
153            right.equivalence_properties().clone(),
154            &JoinType::Full,
155            schema,
156            &[false, false],
157            None,
158            &[],
159        )?;
160
161        // Get output partitioning:
162        // TODO: Optimize the cross join implementation to generate M * N
163        //       partitions.
164        let output_partitioning = adjust_right_output_partitioning(
165            right.output_partitioning(),
166            left.schema().fields.len(),
167        )?;
168
169        Ok(PlanProperties::new(
170            eq_properties,
171            output_partitioning,
172            EmissionType::Final,
173            boundedness_from_children([left, right]),
174        ))
175    }
176
177    /// Returns a new `ExecutionPlan` that computes the same join as this one,
178    /// with the left and right inputs swapped using the  specified
179    /// `partition_mode`.
180    ///
181    /// # Notes:
182    ///
183    /// This function should be called BEFORE inserting any repartitioning
184    /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`]
185    /// for more details.
186    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
187        let new_join =
188            CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
189        reorder_output_after_swap(
190            Arc::new(new_join),
191            &self.left.schema(),
192            &self.right.schema(),
193        )
194    }
195
196    fn with_new_children_and_same_properties(
197        &self,
198        mut children: Vec<Arc<dyn ExecutionPlan>>,
199    ) -> Self {
200        let left = children.swap_remove(0);
201        let right = children.swap_remove(0);
202
203        Self {
204            left,
205            right,
206            metrics: ExecutionPlanMetricsSet::new(),
207            left_fut: Default::default(),
208            cache: Arc::clone(&self.cache),
209            schema: Arc::clone(&self.schema),
210        }
211    }
212}
213
214/// Asynchronously collect the result of the left child
215async fn load_left_input(
216    stream: SendableRecordBatchStream,
217    metrics: BuildProbeJoinMetrics,
218    reservation: MemoryReservation,
219) -> Result<JoinLeftData> {
220    let left_schema = stream.schema();
221
222    // Load all batches and count the rows
223    let (batches, _metrics, reservation) = stream
224        .try_fold(
225            (Vec::new(), metrics, reservation),
226            |(mut batches, metrics, reservation), batch| async {
227                let batch_size = batch.get_array_memory_size();
228                // Reserve memory for incoming batch
229                reservation.try_grow(batch_size)?;
230                // Update metrics
231                metrics.build_mem_used.add(batch_size);
232                metrics.build_input_batches.add(1);
233                metrics.build_input_rows.add(batch.num_rows());
234                // Push batch to output
235                batches.push(batch);
236                Ok((batches, metrics, reservation))
237            },
238        )
239        .await?;
240
241    let merged_batch = concat_batches(&left_schema, &batches)?;
242
243    Ok(JoinLeftData {
244        merged_batch,
245        _reservation: reservation,
246    })
247}
248
249impl DisplayAs for CrossJoinExec {
250    fn fmt_as(
251        &self,
252        t: DisplayFormatType,
253        f: &mut std::fmt::Formatter,
254    ) -> std::fmt::Result {
255        match t {
256            DisplayFormatType::Default | DisplayFormatType::Verbose => {
257                write!(f, "CrossJoinExec")
258            }
259            DisplayFormatType::TreeRender => {
260                // no extra info to display
261                Ok(())
262            }
263        }
264    }
265}
266
impl ExecutionPlan for CrossJoinExec {
    fn name(&self) -> &'static str {
        "CrossJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Build (left) side first, probe (right) side second.
        vec![&self.left, &self.right]
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        check_if_same_properties!(self, children);
        // Rebuild via `new` so the schema and plan properties are recomputed
        // from the new children.
        Ok(Arc::new(CrossJoinExec::new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
        )))
    }

    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
        // Fresh copy with a new (empty) `left_fut` and metrics so the plan
        // can be re-executed without reusing the previously buffered build side.
        let new_exec = CrossJoinExec {
            left: Arc::clone(&self.left),
            right: Arc::clone(&self.right),
            schema: Arc::clone(&self.schema),
            left_fut: Default::default(), // reset the build side!
            metrics: ExecutionPlanMetricsSet::default(),
            cache: Arc::clone(&self.cache),
        };
        Ok(Arc::new(new_exec))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        // Left side must be a single partition so every output partition sees
        // the complete build side; right side can be partitioned arbitrarily.
        vec![
            Distribution::SinglePartition,
            Distribution::UnspecifiedDistribution,
        ]
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        assert_eq_or_internal_err!(
            self.left.output_partitioning().partition_count(),
            1,
            "Invalid CrossJoinExec, the output partition count of the left child must be 1,\
                 consider using CoalescePartitionsExec or the EnforceDistribution rule"
        );

        let stream = self.right.execute(partition, Arc::clone(&context))?;

        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);

        // Initialization of operator-level reservation
        let reservation =
            MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());

        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();

        // Kick off (or attach to) the shared one-time load of the left side;
        // `OnceAsync` ensures only one load runs across all partitions.
        let left_fut = self.left_fut.try_once(|| {
            let left_stream = self.left.execute(0, context)?;

            Ok(load_left_input(
                left_stream,
                join_metrics.clone(),
                reservation,
            ))
        })?;

        // The two arms differ only in the batch transformer: a splitter that
        // enforces the configured batch size, or a pass-through.
        if enforce_batch_size_in_joins {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        // Get the all partitions statistics of the left (every output
        // partition is joined against the entire left side).
        let left_stats = self.left.partition_statistics(None)?;
        let right_stats = self.right.partition_statistics(partition)?;

        Ok(stats_cartesian_product(left_stats, right_stats))
    }

    /// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done,
    /// it returns the new swapped version having the [`CrossJoinExec`] as the top plan.
    /// Otherwise, it returns None.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };

        // Find where the projection's columns straddle the left/right boundary.
        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );

        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }

        // Split the projection into per-side projections pushed below the join.
        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;

        Ok(Some(Arc::new(CrossJoinExec::new(
            Arc::new(new_left),
            Arc::new(new_right),
        ))))
    }
}
426
427/// [left/right]_col_count are required in case the column statistics are None
428fn stats_cartesian_product(
429    left_stats: Statistics,
430    right_stats: Statistics,
431) -> Statistics {
432    let left_row_count = left_stats.num_rows;
433    let right_row_count = right_stats.num_rows;
434
435    // calculate global stats
436    let num_rows = left_row_count.multiply(&right_row_count);
437    // the result size is two times a*b because you have the columns of both left and right
438    let total_byte_size = left_stats
439        .total_byte_size
440        .multiply(&right_stats.total_byte_size)
441        .multiply(&Precision::Exact(2));
442
443    let left_col_stats = left_stats.column_statistics;
444    let right_col_stats = right_stats.column_statistics;
445
446    // the null counts must be multiplied by the row counts of the other side (if defined)
447    // Min, max and distinct_count on the other hand are invariants.
448    let cross_join_stats = left_col_stats
449        .into_iter()
450        .map(|s| ColumnStatistics {
451            null_count: s.null_count.multiply(&right_row_count),
452            distinct_count: s.distinct_count,
453            min_value: s.min_value,
454            max_value: s.max_value,
455            sum_value: s
456                .sum_value
457                .get_value()
458                // Cast the row count into the same type as any existing sum value
459                .and_then(|v| {
460                    Precision::<ScalarValue>::from(right_row_count)
461                        .cast_to(&v.data_type())
462                        .ok()
463                })
464                .map(|row_count| s.sum_value.multiply(&row_count))
465                .unwrap_or(Precision::Absent),
466            byte_size: Precision::Absent,
467        })
468        .chain(right_col_stats.into_iter().map(|s| {
469            ColumnStatistics {
470                null_count: s.null_count.multiply(&left_row_count),
471                distinct_count: s.distinct_count,
472                min_value: s.min_value,
473                max_value: s.max_value,
474                sum_value: s
475                    .sum_value
476                    .get_value()
477                    // Cast the row count into the same type as any existing sum value
478                    .and_then(|v| {
479                        Precision::<ScalarValue>::from(left_row_count)
480                            .cast_to(&v.data_type())
481                            .ok()
482                    })
483                    .map(|row_count| s.sum_value.multiply(&row_count))
484                    .unwrap_or(Precision::Absent),
485                byte_size: Precision::Absent,
486            }
487        }))
488        .collect();
489
490    Statistics {
491        num_rows,
492        total_byte_size,
493        column_statistics: cross_join_stats,
494    }
495}
496
/// A stream that issues [RecordBatch]es as they arrive from the right side of the join.
struct CrossJoinStream<T> {
    /// Output schema of the join (left fields followed by right fields)
    schema: Arc<Schema>,
    /// Future for data from left side (shared across output partitions)
    left_fut: OnceFut<JoinLeftData>,
    /// Right side stream
    right: SendableRecordBatchStream,
    /// Index of the left row currently being combined with the probe batch
    left_index: usize,
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// State of the stream
    state: CrossJoinStreamState,
    /// Left data (copy of the entire buffered left side); starts empty and is
    /// populated once the left future resolves
    left_data: RecordBatch,
    /// Batch transformer (either a size-enforcing splitter or a no-op)
    batch_transformer: T,
}
516
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
    /// Returns the schema of the batches produced by this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
522
/// Represents states of CrossJoinStream
enum CrossJoinStreamState {
    /// Waiting for the buffered left (build) side to finish loading
    WaitBuildSide,
    /// Ready to poll the right (probe) input for the next batch
    FetchProbeBatch,
    /// Holds the currently processed right side batch
    BuildBatches(RecordBatch),
}
530
531impl CrossJoinStreamState {
532    /// Tries to extract RecordBatch from CrossJoinStreamState enum.
533    /// Returns an error if state is not BuildBatches state.
534    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
535        match self {
536            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
537            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
538        }
539    }
540}
541
542fn build_batch(
543    left_index: usize,
544    batch: &RecordBatch,
545    left_data: &RecordBatch,
546    schema: &Schema,
547) -> Result<RecordBatch> {
548    // Repeat value on the left n times
549    let arrays = left_data
550        .columns()
551        .iter()
552        .map(|arr| {
553            let scalar = ScalarValue::try_from_array(arr, left_index)?;
554            scalar.to_array_of_size(batch.num_rows())
555        })
556        .collect::<Result<Vec<_>>>()?;
557
558    RecordBatch::try_new_with_options(
559        Arc::new(schema.clone()),
560        arrays
561            .iter()
562            .chain(batch.columns().iter())
563            .cloned()
564            .collect(),
565        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
566    )
567    .map_err(Into::into)
568}
569
#[async_trait]
// NOTE(review): `#[async_trait]` looks unnecessary on this impl (it contains
// no async fns) — confirm before removing.
impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
    type Item = Result<RecordBatch>;

    /// Delegates to [`CrossJoinStream::poll_next_impl`], which drives the
    /// join's state machine.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}
581
impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Separate implementation function that unpins the [`CrossJoinStream`] so
    /// that partial borrows work correctly
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // Drive the state machine until a handler produces output, finishes
        // the stream, or returns Pending.
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    // Building output is synchronous; record the poll result
                    // in the baseline (output) metrics.
                    let poll = handle_state!(self.build_batches());
                    self.join_metrics.baseline.record_poll(poll)
                }
            };
        }
    }

    /// Collects build (left) side of the join into the state. In case of an empty build batch,
    /// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        let build_timer = self.join_metrics.build_time.timer();
        // Await the shared one-time load of the left side.
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            // Empty build side: the cross product is empty, finish immediately.
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Fetches the probe (right) batch, updates the metrics, and save the batch in the state.
    /// Then, the state is updated to build result batches.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        // Restart the pass over the buffered left rows for the new probe batch.
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            // Right side exhausted: the join is complete.
            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Joins the indexed row of left data with the current probe batch.
    /// If all the results are produced, the state is set to fetch new probe batch.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            match self.batch_transformer.next() {
                // Transformer has no pending output: build the joined batch
                // for the current left row and hand it to the transformer.
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                // Emit the next (possibly split) output batch; move on to the
                // next left row only once the transformer reports it is done.
                Some((batch, last)) => {
                    if last {
                        self.left_index += 1;
                    }

                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            // Every left row has been combined with this probe batch.
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}
680
681#[cfg(test)]
682mod tests {
683    use super::*;
684    use crate::common;
685    use crate::test::{assert_join_metrics, build_table_scan_i32};
686
687    use datafusion_common::{assert_contains, test_util::batches_to_sort_string};
688    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
689    use insta::assert_snapshot;
690
691    async fn join_collect(
692        left: Arc<dyn ExecutionPlan>,
693        right: Arc<dyn ExecutionPlan>,
694        context: Arc<TaskContext>,
695    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
696        let join = CrossJoinExec::new(left, right);
697        let columns_header = columns(&join.schema());
698
699        let stream = join.execute(0, context)?;
700        let batches = common::collect(stream).await?;
701        let metrics = join.metrics().unwrap();
702
703        Ok((columns_header, batches, metrics))
704    }
705
    /// Verifies `stats_cartesian_product` when both sides have exact stats:
    /// row count and byte size multiply (byte size doubled), min/max/distinct
    /// are unchanged, and null_count/sum_value scale by the opposite side's
    /// row count.
    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                    byte_size: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
                byte_size: Precision::Absent,
            }],
        };

        let result = stats_cartesian_product(left, right);

        // Left columns scale by the right row count; right columns by the left.
        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                    byte_size: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                    byte_size: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }
788
789    #[tokio::test]
790    async fn test_stats_cartesian_product_with_unknown_size() {
791        let left_row_count = 11;
792
793        let left = Statistics {
794            num_rows: Precision::Exact(left_row_count),
795            total_byte_size: Precision::Exact(23),
796            column_statistics: vec![
797                ColumnStatistics {
798                    distinct_count: Precision::Exact(5),
799                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
800                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
801                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
802                    null_count: Precision::Exact(0),
803                    byte_size: Precision::Absent,
804                },
805                ColumnStatistics {
806                    distinct_count: Precision::Exact(1),
807                    max_value: Precision::Exact(ScalarValue::from("x")),
808                    min_value: Precision::Exact(ScalarValue::from("a")),
809                    sum_value: Precision::Absent,
810                    null_count: Precision::Exact(3),
811                    byte_size: Precision::Absent,
812                },
813            ],
814        };
815
816        let right = Statistics {
817            num_rows: Precision::Absent,
818            total_byte_size: Precision::Absent,
819            column_statistics: vec![ColumnStatistics {
820                distinct_count: Precision::Exact(3),
821                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
822                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
823                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
824                null_count: Precision::Exact(2),
825                byte_size: Precision::Absent,
826            }],
827        };
828
829        let result = stats_cartesian_product(left, right);
830
831        let expected = Statistics {
832            num_rows: Precision::Absent,
833            total_byte_size: Precision::Absent,
834            column_statistics: vec![
835                ColumnStatistics {
836                    distinct_count: Precision::Exact(5),
837                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
838                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
839                    sum_value: Precision::Absent, // we don't know the row count on the right
840                    null_count: Precision::Absent, // we don't know the row count on the right
841                    byte_size: Precision::Absent,
842                },
843                ColumnStatistics {
844                    distinct_count: Precision::Exact(1),
845                    max_value: Precision::Exact(ScalarValue::from("x")),
846                    min_value: Precision::Exact(ScalarValue::from("a")),
847                    sum_value: Precision::Absent,
848                    null_count: Precision::Absent, // we don't know the row count on the right
849                    byte_size: Precision::Absent,
850                },
851                ColumnStatistics {
852                    distinct_count: Precision::Exact(3),
853                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
854                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
855                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
856                        20 * left_row_count as i64,
857                    ))),
858                    null_count: Precision::Exact(2 * left_row_count),
859                    byte_size: Precision::Absent,
860                },
861            ],
862        };
863
864        assert_eq!(result, expected);
865    }
866
    #[tokio::test]
    async fn test_join() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Build (left) side: 3 rows over columns (a1, b1, c1).
        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 6]),
            ("c1", &vec![7, 8, 9]),
        );
        // Probe (right) side: 2 rows over columns (a2, b2, c2).
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?;

        // Output schema is the left columns followed by the right columns.
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);

        // A cross join pairs every left row with every right row: 3 x 2 = 6 rows.
        // NOTE: this is an insta inline snapshot; its indentation is significant
        // to the macro, so do not reflow the literal below.
        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +----+----+----+----+----+----+
        | a1 | b1 | c1 | a2 | b2 | c2 |
        +----+----+----+----+----+----+
        | 1  | 4  | 7  | 10 | 12 | 14 |
        | 1  | 4  | 7  | 11 | 13 | 15 |
        | 2  | 5  | 8  | 10 | 12 | 14 |
        | 2  | 5  | 8  | 11 | 13 | 15 |
        | 3  | 6  | 9  | 10 | 12 | 14 |
        | 3  | 6  | 9  | 11 | 13 | 15 |
        +----+----+----+----+----+----+
        ");

        // 6 output rows should be reflected in the join metrics.
        assert_join_metrics!(metrics, 6);

        Ok(())
    }
903
904    #[tokio::test]
905    async fn test_overallocation() -> Result<()> {
906        let runtime = RuntimeEnvBuilder::new()
907            .with_memory_limit(100, 1.0)
908            .build_arc()?;
909        let task_ctx = TaskContext::default().with_runtime(runtime);
910        let task_ctx = Arc::new(task_ctx);
911
912        let left = build_table_scan_i32(
913            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
914            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
915            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
916        );
917        let right = build_table_scan_i32(
918            ("a2", &vec![10, 11]),
919            ("b2", &vec![12, 13]),
920            ("c2", &vec![14, 15]),
921        );
922
923        let err = join_collect(left, right, task_ctx).await.unwrap_err();
924
925        assert_contains!(
926            err.to_string(),
927            "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n  CrossJoinExec"
928        );
929
930        Ok(())
931    }
932
933    /// Returns the column names on the schema
934    fn columns(schema: &Schema) -> Vec<String> {
935        schema.fields().iter().map(|f| f.name().clone()).collect()
936    }
937}