datafusion_physical_plan/joins/
cross_join.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the cross join plan for loading the left side of the cross join
//! and producing batches in parallel for the right partitions

use std::{any::Any, sync::Arc, task::Poll};

use super::utils::{
    adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter,
    BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
    StatefulStreamResult,
};
use crate::execution_plan::{boundedness_from_children, EmissionType};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{
    join_allows_pushdown, join_table_borders, new_join_children,
    physical_to_column_exprs, ProjectionExec,
};
use crate::{
    handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution,
    ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};

use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::compute::concat_batches;
use arrow::datatypes::{Fields, Schema, SchemaRef};
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;

use async_trait::async_trait;
use futures::{ready, Stream, StreamExt, TryStreamExt};

/// Data of the left side that is buffered into memory
#[derive(Debug)]
struct JoinLeftData {
    /// Single RecordBatch with all rows from the left side
    merged_batch: RecordBatch,
    /// Track memory reservation for merged_batch. Relies on drop
    /// semantics to release reservation when JoinLeftData is dropped.
    _reservation: MemoryReservation,
}

#[allow(rustdoc::private_intra_doc_links)]
/// Cross Join Execution Plan
///
/// This operator is used when there are no join predicates between two tables
/// and returns the Cartesian product of the two tables.
///
/// Buffers the left input into memory and then streams batches from each
/// partition on the right input, combining them with the buffered left input
/// to generate the output.
///
/// # Clone / Shared State
///
/// Note this structure includes a [`OnceAsync`] that is used to coordinate the
/// loading of the left side with the processing in each output stream.
/// Therefore it cannot be [`Clone`].
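///
/// # Example
///
/// A minimal sketch of constructing and running this operator; `left_plan`,
/// `right_plan`, and `task_ctx` are hypothetical placeholders for two input
/// [`ExecutionPlan`]s and a [`TaskContext`], not values defined here:
///
/// ```text
/// // Output schema is the left fields followed by the right fields.
/// let join = CrossJoinExec::new(left_plan, right_plan);
/// // The left child must have exactly one partition; each right partition
/// // produces one output stream.
/// let stream = join.execute(0, task_ctx)?;
/// ```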
#[derive(Debug)]
pub struct CrossJoinExec {
    /// left (build) side which gets loaded in memory
    pub left: Arc<dyn ExecutionPlan>,
    /// right (probe) side which is combined with the left side
    pub right: Arc<dyn ExecutionPlan>,
    /// The schema once the join is applied
    schema: SchemaRef,
    /// Buffered copy of left (build) side in memory.
    ///
    /// This structure is *shared* across all output streams.
    ///
    /// Each output stream waits on the `OnceAsync` to signal the completion of
    /// the left side loading.
    left_fut: OnceAsync<JoinLeftData>,
    /// Execution plan metrics
    metrics: ExecutionPlanMetricsSet,
    /// Properties such as schema, equivalence properties, ordering, partitioning, etc.
    cache: PlanProperties,
}

impl CrossJoinExec {
    /// Create a new [CrossJoinExec].
    pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
        // left then right
        let (all_columns, metadata) = {
            let left_schema = left.schema();
            let right_schema = right.schema();
            let left_fields = left_schema.fields().iter();
            let right_fields = right_schema.fields().iter();

            let mut metadata = left_schema.metadata().clone();
            metadata.extend(right_schema.metadata().clone());

            (
                left_fields.chain(right_fields).cloned().collect::<Fields>(),
                metadata,
            )
        };

        let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
        let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap();

        CrossJoinExec {
            left,
            right,
            schema,
            left_fut: Default::default(),
            metrics: ExecutionPlanMetricsSet::default(),
            cache,
        }
    }

    /// left (build) side which gets loaded in memory
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// right side which gets combined with left side
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
    ) -> Result<PlanProperties> {
        // Calculate equivalence properties
        // TODO: Check equivalence properties of cross join, it may preserve
        //       ordering in some cases.
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &JoinType::Full,
            schema,
            &[false, false],
            None,
            &[],
        )?;

        // Get output partitioning:
        // TODO: Optimize the cross join implementation to generate M * N
        //       partitions.
        let output_partitioning = adjust_right_output_partitioning(
            right.output_partitioning(),
            left.schema().fields.len(),
        )?;

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            EmissionType::Final,
            boundedness_from_children([left, right]),
        ))
    }

    /// Returns a new `ExecutionPlan` that computes the same join as this one,
    /// with the left and right inputs swapped.
    ///
    /// # Notes:
    ///
    /// This function should be called BEFORE inserting any repartitioning
    /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`]
    /// for more details.
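    ///
    /// The swapped plan is wrapped in a projection that restores the original
    /// left-then-right column order, so callers observe the same output schema.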
    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
        let new_join =
            CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left));
        reorder_output_after_swap(
            Arc::new(new_join),
            &self.left.schema(),
            &self.right.schema(),
        )
    }
}

/// Asynchronously collect the result of the left child
async fn load_left_input(
    stream: SendableRecordBatchStream,
    metrics: BuildProbeJoinMetrics,
    reservation: MemoryReservation,
) -> Result<JoinLeftData> {
    let left_schema = stream.schema();

    // Load all batches and count the rows
    let (batches, _metrics, reservation) = stream
        .try_fold(
            (Vec::new(), metrics, reservation),
            |(mut batches, metrics, mut reservation), batch| async {
                let batch_size = batch.get_array_memory_size();
                // Reserve memory for incoming batch
                reservation.try_grow(batch_size)?;
                // Update metrics
                metrics.build_mem_used.add(batch_size);
                metrics.build_input_batches.add(1);
                metrics.build_input_rows.add(batch.num_rows());
                // Push batch to output
                batches.push(batch);
                Ok((batches, metrics, reservation))
            },
        )
        .await?;

    let merged_batch = concat_batches(&left_schema, &batches)?;

    Ok(JoinLeftData {
        merged_batch,
        _reservation: reservation,
    })
}

impl DisplayAs for CrossJoinExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CrossJoinExec")
            }
            DisplayFormatType::TreeRender => {
                // no extra info to display
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for CrossJoinExec {
    fn name(&self) -> &'static str {
        "CrossJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(CrossJoinExec::new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
        )))
    }

    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
        let new_exec = CrossJoinExec {
            left: Arc::clone(&self.left),
            right: Arc::clone(&self.right),
            schema: Arc::clone(&self.schema),
            left_fut: Default::default(), // reset the build side!
            metrics: ExecutionPlanMetricsSet::default(),
            cache: self.cache.clone(),
        };
        Ok(Arc::new(new_exec))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![
            Distribution::SinglePartition,
            Distribution::UnspecifiedDistribution,
        ]
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        if self.left.output_partitioning().partition_count() != 1 {
            return internal_err!(
                "Invalid CrossJoinExec, the output partition count of the left child must be 1,\
                 consider using CoalescePartitionsExec or the EnforceDistribution rule"
            );
        }

        let stream = self.right.execute(partition, Arc::clone(&context))?;

        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);

        // Initialization of operator-level reservation
        let reservation =
            MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());

        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();

        let left_fut = self.left_fut.try_once(|| {
            let left_stream = self.left.execute(0, context)?;

            Ok(load_left_input(
                left_stream,
                join_metrics.clone(),
                reservation,
            ))
        })?;

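        // When `enforce_batch_size_in_joins` is enabled, wrap the stream with a
        // `BatchSplitter` so produced batches are capped at `batch_size` rows;
        // otherwise `NoopBatchTransformer` passes each built batch through as-is.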
        if enforce_batch_size_in_joins {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(CrossJoinStream {
                schema: Arc::clone(&self.schema),
                left_fut,
                right: stream,
                left_index: 0,
                join_metrics,
                state: CrossJoinStreamState::WaitBuildSide,
                left_data: RecordBatch::new_empty(self.left().schema()),
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        // Get the statistics across all partitions of the left side
        let left_stats = self.left.partition_statistics(None)?;
        let right_stats = self.right.partition_statistics(partition)?;

        Ok(stats_cartesian_product(left_stats, right_stats))
    }
    /// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done,
    /// it returns the new swapped version having the [`CrossJoinExec`] as the top plan.
    /// Otherwise, it returns None.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };

        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );

        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }

        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;

        Ok(Some(Arc::new(CrossJoinExec::new(
            Arc::new(new_left),
            Arc::new(new_right),
        ))))
    }
}

/// Computes the [`Statistics`] of the Cartesian product of `left_stats` and `right_stats`
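///
/// For example (illustrative numbers only): with exactly 10 rows on the left and
/// 4 rows on the right, the product has 10 * 4 = 40 rows, and a left column's
/// null count and sum are each scaled by the right row count, because every left
/// row is repeated once per right row.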
fn stats_cartesian_product(
    left_stats: Statistics,
    right_stats: Statistics,
) -> Statistics {
    let left_row_count = left_stats.num_rows;
    let right_row_count = right_stats.num_rows;

    // calculate global stats
    let num_rows = left_row_count.multiply(&right_row_count);
    // the result size is two times a*b because you have the columns of both left and right
    let total_byte_size = left_stats
        .total_byte_size
        .multiply(&right_stats.total_byte_size)
        .multiply(&Precision::Exact(2));

    let left_col_stats = left_stats.column_statistics;
    let right_col_stats = right_stats.column_statistics;

    // the null counts must be multiplied by the row counts of the other side (if defined)
    // Min, max and distinct_count on the other hand are invariants.
    let cross_join_stats = left_col_stats
        .into_iter()
        .map(|s| ColumnStatistics {
            null_count: s.null_count.multiply(&right_row_count),
            distinct_count: s.distinct_count,
            min_value: s.min_value,
            max_value: s.max_value,
            sum_value: s
                .sum_value
                .get_value()
                // Cast the row count into the same type as any existing sum value
                .and_then(|v| {
                    Precision::<ScalarValue>::from(right_row_count)
                        .cast_to(&v.data_type())
                        .ok()
                })
                .map(|row_count| s.sum_value.multiply(&row_count))
                .unwrap_or(Precision::Absent),
        })
        .chain(right_col_stats.into_iter().map(|s| {
            ColumnStatistics {
                null_count: s.null_count.multiply(&left_row_count),
                distinct_count: s.distinct_count,
                min_value: s.min_value,
                max_value: s.max_value,
                sum_value: s
                    .sum_value
                    .get_value()
                    // Cast the row count into the same type as any existing sum value
                    .and_then(|v| {
                        Precision::<ScalarValue>::from(left_row_count)
                            .cast_to(&v.data_type())
                            .ok()
                    })
                    .map(|row_count| s.sum_value.multiply(&row_count))
                    .unwrap_or(Precision::Absent),
            }
        }))
        .collect();

    Statistics {
        num_rows,
        total_byte_size,
        column_statistics: cross_join_stats,
    }
}

/// A stream that issues [RecordBatch]es as they arrive from the right side of the join.
struct CrossJoinStream<T> {
    /// Schema of the output record batches
    schema: Arc<Schema>,
    /// Future for data from left side
    left_fut: OnceFut<JoinLeftData>,
    /// Right side stream
    right: SendableRecordBatchStream,
    /// Index of the current row on the left
    left_index: usize,
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// State of the stream
    state: CrossJoinStreamState,
    /// Left data (copy of the entire buffered left side)
    left_data: RecordBatch,
    /// Batch transformer
    batch_transformer: T,
}

impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// Represents states of CrossJoinStream
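///
/// The stream starts in `WaitBuildSide`, moves to `FetchProbeBatch` once the
/// buffered left side is ready, and enters `BuildBatches` for each right batch,
/// returning to `FetchProbeBatch` after every left row has been paired with
/// that batch.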
enum CrossJoinStreamState {
    WaitBuildSide,
    FetchProbeBatch,
    /// Holds the currently processed right side batch
    BuildBatches(RecordBatch),
}

impl CrossJoinStreamState {
    /// Tries to extract the RecordBatch from the CrossJoinStreamState enum.
    /// Returns an error if the state is not `BuildBatches`.
    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
        match self {
            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
        }
    }
}

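/// Builds one output [`RecordBatch`] by repeating the left row at `left_index`
/// `batch.num_rows()` times and appending the columns of the right `batch`.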
fn build_batch(
    left_index: usize,
    batch: &RecordBatch,
    left_data: &RecordBatch,
    schema: &Schema,
) -> Result<RecordBatch> {
    // Repeat value on the left n times
    let arrays = left_data
        .columns()
        .iter()
        .map(|arr| {
            let scalar = ScalarValue::try_from_array(arr, left_index)?;
            scalar.to_array_of_size(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    RecordBatch::try_new_with_options(
        Arc::new(schema.clone()),
        arrays
            .iter()
            .chain(batch.columns().iter())
            .cloned()
            .collect(),
        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
    )
    .map_err(Into::into)
}

#[async_trait]
impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}

impl<T: BatchTransformer> CrossJoinStream<T> {
    /// Separate implementation function that unpins the [`CrossJoinStream`] so
    /// that partial borrows work correctly
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            return match self.state {
                CrossJoinStreamState::WaitBuildSide => {
                    handle_state!(ready!(self.collect_build_side(cx)))
                }
                CrossJoinStreamState::FetchProbeBatch => {
                    handle_state!(ready!(self.fetch_probe_batch(cx)))
                }
                CrossJoinStreamState::BuildBatches(_) => {
                    let poll = handle_state!(self.build_batches());
                    self.join_metrics.baseline.record_poll(poll)
                }
            };
        }
    }

    /// Collects the build (left) side of the join into the state. In case of an empty build batch,
    /// the execution terminates. Otherwise, the state is updated to fetch the probe (right) batch.
    fn collect_build_side(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        let build_timer = self.join_metrics.build_time.timer();
        let left_data = match ready!(self.left_fut.get(cx)) {
            Ok(left_data) => left_data,
            Err(e) => return Poll::Ready(Err(e)),
        };
        build_timer.done();

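        // Cloning the merged batch is cheap: Arrow arrays are reference counted,
        // so this copies `Arc` pointers rather than the underlying buffers.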
        let left_data = left_data.merged_batch.clone();
        let result = if left_data.num_rows() == 0 {
            StatefulStreamResult::Ready(None)
        } else {
            self.left_data = left_data;
            self.state = CrossJoinStreamState::FetchProbeBatch;
            StatefulStreamResult::Continue
        };
        Poll::Ready(Ok(result))
    }

    /// Fetches the probe (right) batch, updates the metrics, and saves the batch in the state.
    /// Then, the state is updated to build result batches.
    fn fetch_probe_batch(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
        self.left_index = 0;
        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
            Some(Ok(right_data)) => right_data,
            Some(Err(e)) => return Poll::Ready(Err(e)),
            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
        };
        self.join_metrics.input_batches.add(1);
        self.join_metrics.input_rows.add(right_data.num_rows());

        self.state = CrossJoinStreamState::BuildBatches(right_data);
        Poll::Ready(Ok(StatefulStreamResult::Continue))
    }

    /// Joins the indexed row of the left data with the current probe batch.
    /// If all the results are produced, the state is set to fetch a new probe batch.
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        let right_batch = self.state.try_as_record_batch()?;
        if self.left_index < self.left_data.num_rows() {
            match self.batch_transformer.next() {
                None => {
                    let join_timer = self.join_metrics.join_time.timer();
                    let result = build_batch(
                        self.left_index,
                        right_batch,
                        &self.left_data,
                        &self.schema,
                    );
                    join_timer.done();

                    self.batch_transformer.set_batch(result?);
                }
                Some((batch, last)) => {
                    if last {
                        self.left_index += 1;
                    }

                    self.join_metrics.output_batches.add(1);
                    return Ok(StatefulStreamResult::Ready(Some(batch)));
                }
            }
        } else {
            self.state = CrossJoinStreamState::FetchProbeBatch;
        }
        Ok(StatefulStreamResult::Continue)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::common;
    use crate::test::{assert_join_metrics, build_table_scan_i32};

    use datafusion_common::{assert_contains, test_util::batches_to_sort_string};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use insta::assert_snapshot;

    async fn join_collect(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
        let join = CrossJoinExec::new(left, right);
        let columns_header = columns(&join.schema());

        let stream = join.execute(0, context)?;
        let batches = common::collect(stream).await?;
        let metrics = join.metrics().unwrap();

        Ok((columns_header, batches, metrics))
    }

    #[tokio::test]
    async fn test_stats_cartesian_product() {
        let left_row_count = 11;
        let left_bytes = 23;
        let right_row_count = 7;
        let right_bytes = 27;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(left_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(right_row_count),
            total_byte_size: Precision::Exact(right_bytes),
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Exact(left_row_count * right_row_count),
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        42 * right_row_count as i64,
                    ))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3 * right_row_count),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_stats_cartesian_product_with_unknown_size() {
        let left_row_count = 11;

        let left = Statistics {
            num_rows: Precision::Exact(left_row_count),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![ColumnStatistics {
                distinct_count: Precision::Exact(3),
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                sum_value: Precision::Exact(ScalarValue::Int64(Some(20))),
                null_count: Precision::Exact(2),
            }],
        };

        let result = stats_cartesian_product(left, right);

        let expected = Statistics {
            num_rows: Precision::Absent,
            total_byte_size: Precision::Absent,
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Absent, // we don't know the row count on the right
                    null_count: Precision::Absent, // we don't know the row count on the right
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent, // we don't know the row count on the right
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(
                        20 * left_row_count as i64,
                    ))),
                    null_count: Precision::Exact(2 * left_row_count),
                },
            ],
        };

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_join() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![4, 5, 6]),
            ("c1", &vec![7, 8, 9]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?;

        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);

        assert_snapshot!(batches_to_sort_string(&batches), @r#"
            +----+----+----+----+----+----+
            | a1 | b1 | c1 | a2 | b2 | c2 |
            +----+----+----+----+----+----+
            | 1  | 4  | 7  | 10 | 12 | 14 |
            | 1  | 4  | 7  | 11 | 13 | 15 |
            | 2  | 5  | 8  | 10 | 12 | 14 |
            | 2  | 5  | 8  | 11 | 13 | 15 |
            | 3  | 6  | 9  | 10 | 12 | 14 |
            | 3  | 6  | 9  | 11 | 13 | 15 |
            +----+----+----+----+----+----+
            "#);

        assert_join_metrics!(metrics, 6);

        Ok(())
    }

    #[tokio::test]
    async fn test_overallocation() -> Result<()> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(100, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let left = build_table_scan_i32(
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        );
        let right = build_table_scan_i32(
            ("a2", &vec![10, 11]),
            ("b2", &vec![12, 13]),
            ("c2", &vec![14, 15]),
        );

        let err = join_collect(left, right, task_ctx).await.unwrap_err();

        assert_contains!(
            err.to_string(),
            "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n  CrossJoinExec"
        );

        Ok(())
    }

    /// Returns the column names on the schema
    fn columns(schema: &Schema) -> Vec<String> {
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }
}