datafusion_physical_plan/joins/sort_merge_join/
exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the Sort-Merge join execution plan.
19//! A Sort-Merge join plan consumes two sorted children plans and produces
20//! joined output by given join type and other options.
21
22use std::any::Any;
23use std::fmt::Formatter;
24use std::sync::Arc;
25
26use crate::execution_plan::{boundedness_from_children, EmissionType};
27use crate::expressions::PhysicalSortExpr;
28use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
29use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
30use crate::joins::utils::{
31    build_join_schema, check_join_is_valid, estimate_join_statistics,
32    reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
33    JoinOnRef,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
36use crate::projection::{
37    join_allows_pushdown, join_table_borders, new_join_children,
38    physical_to_column_exprs, update_join_on, ProjectionExec,
39};
40use crate::{
41    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
42    PlanProperties, SendableRecordBatchStream, Statistics,
43};
44
45use arrow::compute::SortOptions;
46use arrow::datatypes::SchemaRef;
47use datafusion_common::{
48    internal_err, plan_err, JoinSide, JoinType, NullEquality, Result,
49};
50use datafusion_execution::memory_pool::MemoryConsumer;
51use datafusion_execution::TaskContext;
52use datafusion_physical_expr::equivalence::join_equivalence_properties;
53use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
54use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
55
56/// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge
57/// join algorithm and applies an optional filter post join. Can be used to join arbitrarily large
58/// inputs where one or both of the inputs don't fit in the available memory.
59///
60/// # Join Expressions
61///
62/// Equi-join predicate (e.g. `<col1> = <col2>`) expressions are represented by [`Self::on`].
63///
64/// Non-equality predicates, which can not be pushed down to join inputs (e.g.
65/// `<col1> != <col2>`) are known as "filter expressions" and are evaluated
66/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional
67/// expressions.
68///
69/// # Sorting
70///
71/// Assumes that both the left and right input to the join are pre-sorted. It is not the
72/// responsibility of this execution plan to sort the inputs.
73///
74/// # "Streamed" vs "Buffered"
75///
76/// The number of record batches of streamed input currently present in the memory will depend
77/// on the output batch size of the execution plan. There is no spilling support for streamed input.
78/// The comparisons are performed from values of join keys in streamed input with the values of
79/// join keys in buffered input. One row in streamed record batch could be matched with multiple rows in
80/// buffered input batches. The streamed input is managed through the states in `StreamedState`
81/// and streamed input batches are represented by `StreamedBatch`.
82///
83/// Buffered input is buffered for all record batches having the same value of join key.
/// If memory usage grows beyond the specified limit and spilling is enabled,
85/// buffered batches could be spilled to disk. If spilling is disabled, the execution
86/// will fail under the same conditions. Multiple record batches of buffered could currently reside
87/// in memory/disk during the execution. The number of buffered batches residing in
88/// memory/disk depends on the number of rows of buffered input having the same value
89/// of join key as that of streamed input rows currently present in memory. Due to pre-sorted inputs,
90/// the algorithm understands when it is not needed anymore, and releases the buffered batches
91/// from memory/disk. The buffered input is managed through the states in `BufferedState`
92/// and buffered input batches are represented by `BufferedBatch`.
93///
94/// Depending on the type of join, left or right input may be selected as streamed or buffered
95/// respectively. For example, in a left-outer join, the left execution plan will be selected as
96/// streamed input while in a right-outer join, the right execution plan will be selected as the
97/// streamed input.
98///
99/// Reference for the algorithm:
100/// <https://en.wikipedia.org/wiki/Sort-merge_join>.
101///
102/// Helpful short video demonstration:
103/// <https://www.youtube.com/watch?v=jiWCPJtDE2c>.
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
    /// Left sorted joining execution plan
    pub left: Arc<dyn ExecutionPlan>,
    /// Right sorted joining execution plan
    pub right: Arc<dyn ExecutionPlan>,
    /// Pairs of (left, right) expressions used as the join keys
    pub on: JoinOn,
    /// Optional filter applied while finding matching rows
    pub filter: Option<JoinFilter>,
    /// How the join is performed
    pub join_type: JoinType,
    /// The schema once the join is applied
    schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Ordering over the left join keys, built in `try_new` by pairing the
    /// left side of each `on` entry with the matching entry of `sort_options`
    left_sort_exprs: LexOrdering,
    /// Ordering over the right join keys, built in `try_new` by pairing the
    /// right side of each `on` entry with the matching entry of `sort_options`
    right_sort_exprs: LexOrdering,
    /// Sort options of join columns used in sorting left and right execution plans
    /// (one entry per `on` pair; `try_new` rejects a length mismatch)
    pub sort_options: Vec<SortOptions>,
    /// Defines the null equality for the join.
    pub null_equality: NullEquality,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
131
132impl SortMergeJoinExec {
133    /// Tries to create a new [SortMergeJoinExec].
134    /// The inputs are sorted using `sort_options` are applied to the columns in the `on`
135    /// # Error
136    /// This function errors when it is not possible to join the left and right sides on keys `on`.
137    pub fn try_new(
138        left: Arc<dyn ExecutionPlan>,
139        right: Arc<dyn ExecutionPlan>,
140        on: JoinOn,
141        filter: Option<JoinFilter>,
142        join_type: JoinType,
143        sort_options: Vec<SortOptions>,
144        null_equality: NullEquality,
145    ) -> Result<Self> {
146        let left_schema = left.schema();
147        let right_schema = right.schema();
148
149        check_join_is_valid(&left_schema, &right_schema, &on)?;
150        if sort_options.len() != on.len() {
151            return plan_err!(
152                "Expected number of sort options: {}, actual: {}",
153                on.len(),
154                sort_options.len()
155            );
156        }
157
158        let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
159            .iter()
160            .zip(sort_options.iter())
161            .map(|((l, r), sort_op)| {
162                let left = PhysicalSortExpr {
163                    expr: Arc::clone(l),
164                    options: *sort_op,
165                };
166                let right = PhysicalSortExpr {
167                    expr: Arc::clone(r),
168                    options: *sort_op,
169                };
170                (left, right)
171            })
172            .unzip();
173        let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
174            return plan_err!(
175                "SortMergeJoinExec requires valid sort expressions for its left side"
176            );
177        };
178        let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
179            return plan_err!(
180                "SortMergeJoinExec requires valid sort expressions for its right side"
181            );
182        };
183
184        let schema =
185            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
186        let cache =
187            Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
188        Ok(Self {
189            left,
190            right,
191            on,
192            filter,
193            join_type,
194            schema,
195            metrics: ExecutionPlanMetricsSet::new(),
196            left_sort_exprs,
197            right_sort_exprs,
198            sort_options,
199            null_equality,
200            cache,
201        })
202    }
203
204    /// Get probe side (e.g streaming side) information for this sort merge join.
205    /// In current implementation, probe side is determined according to join type.
206    pub fn probe_side(join_type: &JoinType) -> JoinSide {
207        // When output schema contains only the right side, probe side is right.
208        // Otherwise probe side is the left side.
209        match join_type {
210            // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226)
211            JoinType::Right
212            | JoinType::RightSemi
213            | JoinType::RightAnti
214            | JoinType::RightMark => JoinSide::Right,
215            JoinType::Inner
216            | JoinType::Left
217            | JoinType::Full
218            | JoinType::LeftAnti
219            | JoinType::LeftSemi
220            | JoinType::LeftMark => JoinSide::Left,
221        }
222    }
223
224    /// Calculate order preservation flags for this sort merge join.
225    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
226        match join_type {
227            JoinType::Inner => vec![true, false],
228            JoinType::Left
229            | JoinType::LeftSemi
230            | JoinType::LeftAnti
231            | JoinType::LeftMark => vec![true, false],
232            JoinType::Right
233            | JoinType::RightSemi
234            | JoinType::RightAnti
235            | JoinType::RightMark => {
236                vec![false, true]
237            }
238            _ => vec![false, false],
239        }
240    }
241
    /// Returns the `(left, right)` expression pairs used as the join keys.
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }
246
    /// Returns a reference to the right input execution plan.
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }
251
    /// Returns the join type of this plan (by value).
    pub fn join_type(&self) -> JoinType {
        self.join_type
    }
256
    /// Returns a reference to the left input execution plan.
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }
261
    /// Returns a reference to the optional join filter (`None` when the join
    /// has no filter).
    pub fn filter(&self) -> &Option<JoinFilter> {
        &self.filter
    }
266
    /// Returns the sort options of the join keys, one entry per `on` pair.
    pub fn sort_options(&self) -> &[SortOptions] {
        &self.sort_options
    }
271
    /// Returns the null equality setting this join was created with.
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }
276
277    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
278    fn compute_properties(
279        left: &Arc<dyn ExecutionPlan>,
280        right: &Arc<dyn ExecutionPlan>,
281        schema: SchemaRef,
282        join_type: JoinType,
283        join_on: JoinOnRef,
284    ) -> Result<PlanProperties> {
285        // Calculate equivalence properties:
286        let eq_properties = join_equivalence_properties(
287            left.equivalence_properties().clone(),
288            right.equivalence_properties().clone(),
289            &join_type,
290            schema,
291            &Self::maintains_input_order(join_type),
292            Some(Self::probe_side(&join_type)),
293            join_on,
294        )?;
295
296        let output_partitioning =
297            symmetric_join_output_partitioning(left, right, &join_type)?;
298
299        Ok(PlanProperties::new(
300            eq_properties,
301            output_partitioning,
302            EmissionType::Incremental,
303            boundedness_from_children([left, right]),
304        ))
305    }
306
307    /// # Notes:
308    ///
309    /// This function should be called BEFORE inserting any repartitioning
310    /// operators on the join's children. Check [`super::super::HashJoinExec::swap_inputs`]
311    /// for more details.
312    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
313        let left = self.left();
314        let right = self.right();
315        let new_join = SortMergeJoinExec::try_new(
316            Arc::clone(right),
317            Arc::clone(left),
318            self.on()
319                .iter()
320                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
321                .collect::<Vec<_>>(),
322            self.filter().as_ref().map(JoinFilter::swap),
323            self.join_type().swap(),
324            self.sort_options.clone(),
325            self.null_equality,
326        )?;
327
328        // TODO: OR this condition with having a built-in projection (like
329        //       ordinary hash join) when we support it.
330        if matches!(
331            self.join_type(),
332            JoinType::LeftSemi
333                | JoinType::RightSemi
334                | JoinType::LeftAnti
335                | JoinType::RightAnti
336        ) {
337            Ok(Arc::new(new_join))
338        } else {
339            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
340        }
341    }
342}
343
344impl DisplayAs for SortMergeJoinExec {
345    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
346        match t {
347            DisplayFormatType::Default | DisplayFormatType::Verbose => {
348                let on = self
349                    .on
350                    .iter()
351                    .map(|(c1, c2)| format!("({c1}, {c2})"))
352                    .collect::<Vec<String>>()
353                    .join(", ");
354                write!(
355                    f,
356                    "SortMergeJoin: join_type={:?}, on=[{}]{}",
357                    self.join_type,
358                    on,
359                    self.filter.as_ref().map_or("".to_string(), |f| format!(
360                        ", filter={}",
361                        f.expression()
362                    ))
363                )
364            }
365            DisplayFormatType::TreeRender => {
366                let on = self
367                    .on
368                    .iter()
369                    .map(|(c1, c2)| {
370                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
371                    })
372                    .collect::<Vec<String>>()
373                    .join(", ");
374
375                if self.join_type() != JoinType::Inner {
376                    writeln!(f, "join_type={:?}", self.join_type)?;
377                }
378                writeln!(f, "on={on}")
379            }
380        }
381    }
382}
383
384impl ExecutionPlan for SortMergeJoinExec {
    /// Returns the static name of this operator, `"SortMergeJoinExec"`.
    fn name(&self) -> &'static str {
        "SortMergeJoinExec"
    }
388
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
392
    /// Returns the plan properties cached when this plan was constructed.
    fn properties(&self) -> &PlanProperties {
        &self.cache
    }
396
397    fn required_input_distribution(&self) -> Vec<Distribution> {
398        let (left_expr, right_expr) = self
399            .on
400            .iter()
401            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
402            .unzip();
403        vec![
404            Distribution::HashPartitioned(left_expr),
405            Distribution::HashPartitioned(right_expr),
406        ]
407    }
408
409    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
410        vec![
411            Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
412            Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
413        ]
414    }
415
    /// Reports, per input (`[left, right]`), whether its order is maintained;
    /// delegates to the associated function of the same name using this
    /// plan's join type.
    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order(self.join_type)
    }
419
    /// Returns the two inputs of the join in `[left, right]` order.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }
423
424    fn with_new_children(
425        self: Arc<Self>,
426        children: Vec<Arc<dyn ExecutionPlan>>,
427    ) -> Result<Arc<dyn ExecutionPlan>> {
428        match &children[..] {
429            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
430                Arc::clone(left),
431                Arc::clone(right),
432                self.on.clone(),
433                self.filter.clone(),
434                self.join_type,
435                self.sort_options.clone(),
436                self.null_equality,
437            )?)),
438            _ => internal_err!("SortMergeJoin wrong number of children"),
439        }
440    }
441
442    fn execute(
443        &self,
444        partition: usize,
445        context: Arc<TaskContext>,
446    ) -> Result<SendableRecordBatchStream> {
447        let left_partitions = self.left.output_partitioning().partition_count();
448        let right_partitions = self.right.output_partitioning().partition_count();
449        if left_partitions != right_partitions {
450            return internal_err!(
451                "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
452                 consider using RepartitionExec"
453            );
454        }
455        let (on_left, on_right) = self.on.iter().cloned().unzip();
456        let (streamed, buffered, on_streamed, on_buffered) =
457            if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
458                (
459                    Arc::clone(&self.left),
460                    Arc::clone(&self.right),
461                    on_left,
462                    on_right,
463                )
464            } else {
465                (
466                    Arc::clone(&self.right),
467                    Arc::clone(&self.left),
468                    on_right,
469                    on_left,
470                )
471            };
472
473        // execute children plans
474        let streamed = streamed.execute(partition, Arc::clone(&context))?;
475        let buffered = buffered.execute(partition, Arc::clone(&context))?;
476
477        // create output buffer
478        let batch_size = context.session_config().batch_size();
479
480        // create memory reservation
481        let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
482            .register(context.memory_pool());
483
484        // create join stream
485        Ok(Box::pin(SortMergeJoinStream::try_new(
486            context.session_config().spill_compression(),
487            Arc::clone(&self.schema),
488            self.sort_options.clone(),
489            self.null_equality,
490            streamed,
491            buffered,
492            on_streamed,
493            on_buffered,
494            self.filter.clone(),
495            self.join_type,
496            batch_size,
497            SortMergeJoinMetrics::new(partition, &self.metrics),
498            reservation,
499            context.runtime_env(),
500        )?))
501    }
502
    /// Returns a snapshot of the metrics recorded for this plan (always `Some`).
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
506
    /// Returns statistics for the whole plan; equivalent to
    /// `partition_statistics(None)`.
    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }
510
511    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
512        if partition.is_some() {
513            return Ok(Statistics::new_unknown(&self.schema()));
514        }
515        // TODO stats: it is not possible in general to know the output size of joins
516        // There are some special cases though, for example:
517        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
518        estimate_join_statistics(
519            self.left.partition_statistics(None)?,
520            self.right.partition_statistics(None)?,
521            self.on.clone(),
522            &self.join_type,
523            &self.schema,
524        )
525    }
526
527    /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done,
528    /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan.
529    /// Otherwise, it returns None.
530    fn try_swapping_with_projection(
531        &self,
532        projection: &ProjectionExec,
533    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
534        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
535        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
536        else {
537            return Ok(None);
538        };
539
540        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
541            self.left().schema().fields().len(),
542            &projection_as_columns,
543        );
544
545        if !join_allows_pushdown(
546            &projection_as_columns,
547            &self.schema(),
548            far_right_left_col_ind,
549            far_left_right_col_ind,
550        ) {
551            return Ok(None);
552        }
553
554        let Some(new_on) = update_join_on(
555            &projection_as_columns[0..=far_right_left_col_ind as _],
556            &projection_as_columns[far_left_right_col_ind as _..],
557            self.on(),
558            self.left().schema().fields().len(),
559        ) else {
560            return Ok(None);
561        };
562
563        let (new_left, new_right) = new_join_children(
564            &projection_as_columns,
565            far_right_left_col_ind,
566            far_left_right_col_ind,
567            self.children()[0],
568            self.children()[1],
569        )?;
570
571        Ok(Some(Arc::new(SortMergeJoinExec::try_new(
572            Arc::new(new_left),
573            Arc::new(new_right),
574            new_on,
575            self.filter.clone(),
576            self.join_type,
577            self.sort_options.clone(),
578            self.null_equality,
579        )?)))
580    }
581}
582
583#[cfg(test)]
584mod tests {
585    use std::sync::Arc;
586
587    use arrow::array::{
588        builder::{BooleanBuilder, UInt64Builder},
589        BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray,
590        Int32Array, RecordBatch, UInt64Array,
591    };
592    use arrow::compute::{concat_batches, filter_record_batch, SortOptions};
593    use arrow::datatypes::{DataType, Field, Schema};
594
595    use datafusion_common::JoinType::*;
596    use datafusion_common::{
597        assert_batches_eq, assert_contains, JoinType, NullEquality, Result,
598    };
599    use datafusion_common::{
600        test_util::{batches_to_sort_string, batches_to_string},
601        JoinSide,
602    };
603    use datafusion_execution::config::SessionConfig;
604    use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
605    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
606    use datafusion_execution::TaskContext;
607    use datafusion_expr::Operator;
608    use datafusion_physical_expr::expressions::BinaryExpr;
609    use insta::{allow_duplicates, assert_snapshot};
610
611    use crate::{
612        expressions::Column,
613        joins::sort_merge_join::stream::{
614            get_corrected_filter_mask, JoinedRecordBatches,
615        },
616    };
617
618    use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
619    use crate::joins::SortMergeJoinExec;
620    use crate::test::TestMemoryExec;
621    use crate::test::{build_table_i32, build_table_i32_two_cols};
622    use crate::{common, ExecutionPlan};
623
624    fn build_table(
625        a: (&str, &Vec<i32>),
626        b: (&str, &Vec<i32>),
627        c: (&str, &Vec<i32>),
628    ) -> Arc<dyn ExecutionPlan> {
629        let batch = build_table_i32(a, b, c);
630        let schema = batch.schema();
631        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
632    }
633
634    fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
635        let schema = batches.first().unwrap().schema();
636        TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap()
637    }
638
639    fn build_date_table(
640        a: (&str, &Vec<i32>),
641        b: (&str, &Vec<i32>),
642        c: (&str, &Vec<i32>),
643    ) -> Arc<dyn ExecutionPlan> {
644        let schema = Schema::new(vec![
645            Field::new(a.0, DataType::Date32, false),
646            Field::new(b.0, DataType::Date32, false),
647            Field::new(c.0, DataType::Date32, false),
648        ]);
649
650        let batch = RecordBatch::try_new(
651            Arc::new(schema),
652            vec![
653                Arc::new(Date32Array::from(a.1.clone())),
654                Arc::new(Date32Array::from(b.1.clone())),
655                Arc::new(Date32Array::from(c.1.clone())),
656            ],
657        )
658        .unwrap();
659
660        let schema = batch.schema();
661        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
662    }
663
664    fn build_date64_table(
665        a: (&str, &Vec<i64>),
666        b: (&str, &Vec<i64>),
667        c: (&str, &Vec<i64>),
668    ) -> Arc<dyn ExecutionPlan> {
669        let schema = Schema::new(vec![
670            Field::new(a.0, DataType::Date64, false),
671            Field::new(b.0, DataType::Date64, false),
672            Field::new(c.0, DataType::Date64, false),
673        ]);
674
675        let batch = RecordBatch::try_new(
676            Arc::new(schema),
677            vec![
678                Arc::new(Date64Array::from(a.1.clone())),
679                Arc::new(Date64Array::from(b.1.clone())),
680                Arc::new(Date64Array::from(c.1.clone())),
681            ],
682        )
683        .unwrap();
684
685        let schema = batch.schema();
686        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
687    }
688
689    fn build_binary_table(
690        a: (&str, &Vec<&[u8]>),
691        b: (&str, &Vec<i32>),
692        c: (&str, &Vec<i32>),
693    ) -> Arc<dyn ExecutionPlan> {
694        let schema = Schema::new(vec![
695            Field::new(a.0, DataType::Binary, false),
696            Field::new(b.0, DataType::Int32, false),
697            Field::new(c.0, DataType::Int32, false),
698        ]);
699
700        let batch = RecordBatch::try_new(
701            Arc::new(schema),
702            vec![
703                Arc::new(BinaryArray::from(a.1.clone())),
704                Arc::new(Int32Array::from(b.1.clone())),
705                Arc::new(Int32Array::from(c.1.clone())),
706            ],
707        )
708        .unwrap();
709
710        let schema = batch.schema();
711        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
712    }
713
714    fn build_fixed_size_binary_table(
715        a: (&str, &Vec<&[u8]>),
716        b: (&str, &Vec<i32>),
717        c: (&str, &Vec<i32>),
718    ) -> Arc<dyn ExecutionPlan> {
719        let schema = Schema::new(vec![
720            Field::new(a.0, DataType::FixedSizeBinary(3), false),
721            Field::new(b.0, DataType::Int32, false),
722            Field::new(c.0, DataType::Int32, false),
723        ]);
724
725        let batch = RecordBatch::try_new(
726            Arc::new(schema),
727            vec![
728                Arc::new(FixedSizeBinaryArray::from(a.1.clone())),
729                Arc::new(Int32Array::from(b.1.clone())),
730                Arc::new(Int32Array::from(c.1.clone())),
731            ],
732        )
733        .unwrap();
734
735        let schema = batch.schema();
736        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
737    }
738
739    /// returns a table with 3 columns of i32 in memory
740    pub fn build_table_i32_nullable(
741        a: (&str, &Vec<Option<i32>>),
742        b: (&str, &Vec<Option<i32>>),
743        c: (&str, &Vec<Option<i32>>),
744    ) -> Arc<dyn ExecutionPlan> {
745        let schema = Arc::new(Schema::new(vec![
746            Field::new(a.0, DataType::Int32, true),
747            Field::new(b.0, DataType::Int32, true),
748            Field::new(c.0, DataType::Int32, true),
749        ]));
750        let batch = RecordBatch::try_new(
751            Arc::clone(&schema),
752            vec![
753                Arc::new(Int32Array::from(a.1.clone())),
754                Arc::new(Int32Array::from(b.1.clone())),
755                Arc::new(Int32Array::from(c.1.clone())),
756            ],
757        )
758        .unwrap();
759        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
760    }
761
762    pub fn build_table_two_cols(
763        a: (&str, &Vec<i32>),
764        b: (&str, &Vec<i32>),
765    ) -> Arc<dyn ExecutionPlan> {
766        let batch = build_table_i32_two_cols(a, b);
767        let schema = batch.schema();
768        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
769    }
770
771    fn join(
772        left: Arc<dyn ExecutionPlan>,
773        right: Arc<dyn ExecutionPlan>,
774        on: JoinOn,
775        join_type: JoinType,
776    ) -> Result<SortMergeJoinExec> {
777        let sort_options = vec![SortOptions::default(); on.len()];
778        SortMergeJoinExec::try_new(
779            left,
780            right,
781            on,
782            None,
783            join_type,
784            sort_options,
785            NullEquality::NullEqualsNothing,
786        )
787    }
788
789    fn join_with_options(
790        left: Arc<dyn ExecutionPlan>,
791        right: Arc<dyn ExecutionPlan>,
792        on: JoinOn,
793        join_type: JoinType,
794        sort_options: Vec<SortOptions>,
795        null_equality: NullEquality,
796    ) -> Result<SortMergeJoinExec> {
797        SortMergeJoinExec::try_new(
798            left,
799            right,
800            on,
801            None,
802            join_type,
803            sort_options,
804            null_equality,
805        )
806    }
807
808    fn join_with_filter(
809        left: Arc<dyn ExecutionPlan>,
810        right: Arc<dyn ExecutionPlan>,
811        on: JoinOn,
812        filter: JoinFilter,
813        join_type: JoinType,
814        sort_options: Vec<SortOptions>,
815        null_equality: NullEquality,
816    ) -> Result<SortMergeJoinExec> {
817        SortMergeJoinExec::try_new(
818            left,
819            right,
820            on,
821            Some(filter),
822            join_type,
823            sort_options,
824            null_equality,
825        )
826    }
827
828    async fn join_collect(
829        left: Arc<dyn ExecutionPlan>,
830        right: Arc<dyn ExecutionPlan>,
831        on: JoinOn,
832        join_type: JoinType,
833    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
834        let sort_options = vec![SortOptions::default(); on.len()];
835        join_collect_with_options(
836            left,
837            right,
838            on,
839            join_type,
840            sort_options,
841            NullEquality::NullEqualsNothing,
842        )
843        .await
844    }
845
846    async fn join_collect_with_filter(
847        left: Arc<dyn ExecutionPlan>,
848        right: Arc<dyn ExecutionPlan>,
849        on: JoinOn,
850        filter: JoinFilter,
851        join_type: JoinType,
852    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
853        let sort_options = vec![SortOptions::default(); on.len()];
854
855        let task_ctx = Arc::new(TaskContext::default());
856        let join = join_with_filter(
857            left,
858            right,
859            on,
860            filter,
861            join_type,
862            sort_options,
863            NullEquality::NullEqualsNothing,
864        )?;
865        let columns = columns(&join.schema());
866
867        let stream = join.execute(0, task_ctx)?;
868        let batches = common::collect(stream).await?;
869        Ok((columns, batches))
870    }
871
872    async fn join_collect_with_options(
873        left: Arc<dyn ExecutionPlan>,
874        right: Arc<dyn ExecutionPlan>,
875        on: JoinOn,
876        join_type: JoinType,
877        sort_options: Vec<SortOptions>,
878        null_equality: NullEquality,
879    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
880        let task_ctx = Arc::new(TaskContext::default());
881        let join =
882            join_with_options(left, right, on, join_type, sort_options, null_equality)?;
883        let columns = columns(&join.schema());
884
885        let stream = join.execute(0, task_ctx)?;
886        let batches = common::collect(stream).await?;
887        Ok((columns, batches))
888    }
889
890    async fn join_collect_batch_size_equals_two(
891        left: Arc<dyn ExecutionPlan>,
892        right: Arc<dyn ExecutionPlan>,
893        on: JoinOn,
894        join_type: JoinType,
895    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
896        let task_ctx = TaskContext::default()
897            .with_session_config(SessionConfig::new().with_batch_size(2));
898        let task_ctx = Arc::new(task_ctx);
899        let join = join(left, right, on, join_type)?;
900        let columns = columns(&join.schema());
901
902        let stream = join.execute(0, task_ctx)?;
903        let batches = common::collect(stream).await?;
904        Ok((columns, batches))
905    }
906
907    #[tokio::test]
908    async fn join_inner_one() -> Result<()> {
909        let left = build_table(
910            ("a1", &vec![1, 2, 3]),
911            ("b1", &vec![4, 5, 5]), // this has a repetition
912            ("c1", &vec![7, 8, 9]),
913        );
914        let right = build_table(
915            ("a2", &vec![10, 20, 30]),
916            ("b1", &vec![4, 5, 6]),
917            ("c2", &vec![70, 80, 90]),
918        );
919
920        let on = vec![(
921            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
922            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
923        )];
924
925        let (_, batches) = join_collect(left, right, on, Inner).await?;
926
927        // The output order is important as SMJ preserves sortedness
928        assert_snapshot!(batches_to_string(&batches), @r#"
929            +----+----+----+----+----+----+
930            | a1 | b1 | c1 | a2 | b1 | c2 |
931            +----+----+----+----+----+----+
932            | 1  | 4  | 7  | 10 | 4  | 70 |
933            | 2  | 5  | 8  | 20 | 5  | 80 |
934            | 3  | 5  | 9  | 20 | 5  | 80 |
935            +----+----+----+----+----+----+
936            "#);
937        Ok(())
938    }
939
940    #[tokio::test]
941    async fn join_inner_two() -> Result<()> {
942        let left = build_table(
943            ("a1", &vec![1, 2, 2]),
944            ("b2", &vec![1, 2, 2]),
945            ("c1", &vec![7, 8, 9]),
946        );
947        let right = build_table(
948            ("a1", &vec![1, 2, 3]),
949            ("b2", &vec![1, 2, 2]),
950            ("c2", &vec![70, 80, 90]),
951        );
952        let on = vec![
953            (
954                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
955                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
956            ),
957            (
958                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
959                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
960            ),
961        ];
962
963        let (_columns, batches) = join_collect(left, right, on, Inner).await?;
964
965        // The output order is important as SMJ preserves sortedness
966        assert_snapshot!(batches_to_string(&batches), @r#"
967            +----+----+----+----+----+----+
968            | a1 | b2 | c1 | a1 | b2 | c2 |
969            +----+----+----+----+----+----+
970            | 1  | 1  | 7  | 1  | 1  | 70 |
971            | 2  | 2  | 8  | 2  | 2  | 80 |
972            | 2  | 2  | 9  | 2  | 2  | 80 |
973            +----+----+----+----+----+----+
974            "#);
975        Ok(())
976    }
977
978    #[tokio::test]
979    async fn join_inner_two_two() -> Result<()> {
980        let left = build_table(
981            ("a1", &vec![1, 1, 2]),
982            ("b2", &vec![1, 1, 2]),
983            ("c1", &vec![7, 8, 9]),
984        );
985        let right = build_table(
986            ("a1", &vec![1, 1, 3]),
987            ("b2", &vec![1, 1, 2]),
988            ("c2", &vec![70, 80, 90]),
989        );
990        let on = vec![
991            (
992                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
993                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
994            ),
995            (
996                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
997                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
998            ),
999        ];
1000
1001        let (_columns, batches) = join_collect(left, right, on, Inner).await?;
1002
1003        // The output order is important as SMJ preserves sortedness
1004        assert_snapshot!(batches_to_string(&batches), @r#"
1005            +----+----+----+----+----+----+
1006            | a1 | b2 | c1 | a1 | b2 | c2 |
1007            +----+----+----+----+----+----+
1008            | 1  | 1  | 7  | 1  | 1  | 70 |
1009            | 1  | 1  | 7  | 1  | 1  | 80 |
1010            | 1  | 1  | 8  | 1  | 1  | 70 |
1011            | 1  | 1  | 8  | 1  | 1  | 80 |
1012            +----+----+----+----+----+----+
1013            "#);
1014        Ok(())
1015    }
1016
1017    #[tokio::test]
1018    async fn join_inner_with_nulls() -> Result<()> {
1019        let left = build_table_i32_nullable(
1020            ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
1021            ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
1022            ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
1023        );
1024        let right = build_table_i32_nullable(
1025            ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
1026            ("b2", &vec![None, Some(1), Some(2), Some(2)]),
1027            ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
1028        );
1029        let on = vec![
1030            (
1031                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1032                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1033            ),
1034            (
1035                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
1036                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1037            ),
1038        ];
1039
1040        let (_, batches) = join_collect(left, right, on, Inner).await?;
1041        // The output order is important as SMJ preserves sortedness
1042        assert_snapshot!(batches_to_string(&batches), @r#"
1043            +----+----+----+----+----+----+
1044            | a1 | b2 | c1 | a1 | b2 | c2 |
1045            +----+----+----+----+----+----+
1046            | 1  | 1  |    | 1  | 1  | 70 |
1047            | 2  | 2  | 8  | 2  | 2  | 80 |
1048            | 2  | 2  | 9  | 2  | 2  | 80 |
1049            +----+----+----+----+----+----+
1050            "#);
1051        Ok(())
1052    }
1053
1054    #[tokio::test]
1055    async fn join_inner_with_nulls_with_options() -> Result<()> {
1056        let left = build_table_i32_nullable(
1057            ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]),
1058            ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field
1059            ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field
1060        );
1061        let right = build_table_i32_nullable(
1062            ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]),
1063            ("b2", &vec![Some(2), Some(2), Some(1), None]),
1064            ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]),
1065        );
1066        let on = vec![
1067            (
1068                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1069                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1070            ),
1071            (
1072                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
1073                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1074            ),
1075        ];
1076        let (_, batches) = join_collect_with_options(
1077            left,
1078            right,
1079            on,
1080            Inner,
1081            vec![
1082                SortOptions {
1083                    descending: true,
1084                    nulls_first: false,
1085                };
1086                2
1087            ],
1088            NullEquality::NullEqualsNull,
1089        )
1090        .await?;
1091        // The output order is important as SMJ preserves sortedness
1092        assert_snapshot!(batches_to_string(&batches), @r#"
1093            +----+----+----+----+----+----+
1094            | a1 | b2 | c1 | a1 | b2 | c2 |
1095            +----+----+----+----+----+----+
1096            | 2  | 2  | 9  | 2  | 2  | 80 |
1097            | 2  | 2  | 8  | 2  | 2  | 80 |
1098            | 1  | 1  |    | 1  | 1  | 70 |
1099            | 1  |    | 1  | 1  |    | 10 |
1100            +----+----+----+----+----+----+
1101            "#);
1102        Ok(())
1103    }
1104
1105    #[tokio::test]
1106    async fn join_inner_output_two_batches() -> Result<()> {
1107        let left = build_table(
1108            ("a1", &vec![1, 2, 2]),
1109            ("b2", &vec![1, 2, 2]),
1110            ("c1", &vec![7, 8, 9]),
1111        );
1112        let right = build_table(
1113            ("a1", &vec![1, 2, 3]),
1114            ("b2", &vec![1, 2, 2]),
1115            ("c2", &vec![70, 80, 90]),
1116        );
1117        let on = vec![
1118            (
1119                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1120                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1121            ),
1122            (
1123                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
1124                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1125            ),
1126        ];
1127
1128        let (_, batches) =
1129            join_collect_batch_size_equals_two(left, right, on, Inner).await?;
1130        assert_eq!(batches.len(), 2);
1131        assert_eq!(batches[0].num_rows(), 2);
1132        assert_eq!(batches[1].num_rows(), 1);
1133        // The output order is important as SMJ preserves sortedness
1134        assert_snapshot!(batches_to_string(&batches), @r#"
1135            +----+----+----+----+----+----+
1136            | a1 | b2 | c1 | a1 | b2 | c2 |
1137            +----+----+----+----+----+----+
1138            | 1  | 1  | 7  | 1  | 1  | 70 |
1139            | 2  | 2  | 8  | 2  | 2  | 80 |
1140            | 2  | 2  | 9  | 2  | 2  | 80 |
1141            +----+----+----+----+----+----+
1142            "#);
1143        Ok(())
1144    }
1145
1146    #[tokio::test]
1147    async fn join_left_one() -> Result<()> {
1148        let left = build_table(
1149            ("a1", &vec![1, 2, 3]),
1150            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
1151            ("c1", &vec![7, 8, 9]),
1152        );
1153        let right = build_table(
1154            ("a2", &vec![10, 20, 30]),
1155            ("b1", &vec![4, 5, 6]),
1156            ("c2", &vec![70, 80, 90]),
1157        );
1158        let on = vec![(
1159            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1160            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1161        )];
1162
1163        let (_, batches) = join_collect(left, right, on, Left).await?;
1164        // The output order is important as SMJ preserves sortedness
1165        assert_snapshot!(batches_to_string(&batches), @r#"
1166            +----+----+----+----+----+----+
1167            | a1 | b1 | c1 | a2 | b1 | c2 |
1168            +----+----+----+----+----+----+
1169            | 1  | 4  | 7  | 10 | 4  | 70 |
1170            | 2  | 5  | 8  | 20 | 5  | 80 |
1171            | 3  | 7  | 9  |    |    |    |
1172            +----+----+----+----+----+----+
1173            "#);
1174        Ok(())
1175    }
1176
1177    #[tokio::test]
1178    async fn join_right_one() -> Result<()> {
1179        let left = build_table(
1180            ("a1", &vec![1, 2, 3]),
1181            ("b1", &vec![4, 5, 7]),
1182            ("c1", &vec![7, 8, 9]),
1183        );
1184        let right = build_table(
1185            ("a2", &vec![10, 20, 30]),
1186            ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
1187            ("c2", &vec![70, 80, 90]),
1188        );
1189        let on = vec![(
1190            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1191            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1192        )];
1193
1194        let (_, batches) = join_collect(left, right, on, Right).await?;
1195        // The output order is important as SMJ preserves sortedness
1196        assert_snapshot!(batches_to_string(&batches), @r#"
1197            +----+----+----+----+----+----+
1198            | a1 | b1 | c1 | a2 | b1 | c2 |
1199            +----+----+----+----+----+----+
1200            | 1  | 4  | 7  | 10 | 4  | 70 |
1201            | 2  | 5  | 8  | 20 | 5  | 80 |
1202            |    |    |    | 30 | 6  | 90 |
1203            +----+----+----+----+----+----+
1204            "#);
1205        Ok(())
1206    }
1207
1208    #[tokio::test]
1209    async fn join_full_one() -> Result<()> {
1210        let left = build_table(
1211            ("a1", &vec![1, 2, 3]),
1212            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
1213            ("c1", &vec![7, 8, 9]),
1214        );
1215        let right = build_table(
1216            ("a2", &vec![10, 20, 30]),
1217            ("b2", &vec![4, 5, 6]),
1218            ("c2", &vec![70, 80, 90]),
1219        );
1220        let on = vec![(
1221            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
1222            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
1223        )];
1224
1225        let (_, batches) = join_collect(left, right, on, Full).await?;
1226        // The output order is important as SMJ preserves sortedness
1227        assert_snapshot!(batches_to_sort_string(&batches), @r#"
1228            +----+----+----+----+----+----+
1229            | a1 | b1 | c1 | a2 | b2 | c2 |
1230            +----+----+----+----+----+----+
1231            |    |    |    | 30 | 6  | 90 |
1232            | 1  | 4  | 7  | 10 | 4  | 70 |
1233            | 2  | 5  | 8  | 20 | 5  | 80 |
1234            | 3  | 7  | 9  |    |    |    |
1235            +----+----+----+----+----+----+
1236            "#);
1237        Ok(())
1238    }
1239
1240    #[tokio::test]
1241    async fn join_left_anti() -> Result<()> {
1242        let left = build_table(
1243            ("a1", &vec![1, 2, 2, 3, 5]),
1244            ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
1245            ("c1", &vec![7, 8, 8, 9, 11]),
1246        );
1247        let right = build_table(
1248            ("a2", &vec![10, 20, 30]),
1249            ("b1", &vec![4, 5, 6]),
1250            ("c2", &vec![70, 80, 90]),
1251        );
1252        let on = vec![(
1253            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1254            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1255        )];
1256
1257        let (_, batches) = join_collect(left, right, on, LeftAnti).await?;
1258
1259        // The output order is important as SMJ preserves sortedness
1260        assert_snapshot!(batches_to_string(&batches), @r#"
1261            +----+----+----+
1262            | a1 | b1 | c1 |
1263            +----+----+----+
1264            | 3  | 7  | 9  |
1265            | 5  | 7  | 11 |
1266            +----+----+----+
1267            "#);
1268        Ok(())
1269    }
1270
1271    #[tokio::test]
1272    async fn join_right_anti_one_one() -> Result<()> {
1273        let left = build_table(
1274            ("a1", &vec![1, 2, 2]),
1275            ("b1", &vec![4, 5, 5]),
1276            ("c1", &vec![7, 8, 8]),
1277        );
1278        let right =
1279            build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
1280        let on = vec![(
1281            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1282            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1283        )];
1284
1285        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
1286        // The output order is important as SMJ preserves sortedness
1287        assert_snapshot!(batches_to_string(&batches), @r#"
1288            +----+----+
1289            | a2 | b1 |
1290            +----+----+
1291            | 30 | 6  |
1292            +----+----+
1293            "#);
1294
1295        let left2 = build_table(
1296            ("a1", &vec![1, 2, 2]),
1297            ("b1", &vec![4, 5, 5]),
1298            ("c1", &vec![7, 8, 8]),
1299        );
1300        let right2 = build_table(
1301            ("a2", &vec![10, 20, 30]),
1302            ("b1", &vec![4, 5, 6]),
1303            ("c2", &vec![70, 80, 90]),
1304        );
1305
1306        let on = vec![(
1307            Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _,
1308            Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _,
1309        )];
1310
1311        let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?;
1312        // The output order is important as SMJ preserves sortedness
1313        assert_snapshot!(batches_to_string(&batches2), @r#"
1314            +----+----+----+
1315            | a2 | b1 | c2 |
1316            +----+----+----+
1317            | 30 | 6  | 90 |
1318            +----+----+----+
1319            "#);
1320
1321        Ok(())
1322    }
1323
1324    #[tokio::test]
1325    async fn join_right_anti_two_two() -> Result<()> {
1326        let left = build_table(
1327            ("a1", &vec![1, 2, 2]),
1328            ("b1", &vec![4, 5, 5]),
1329            ("c1", &vec![7, 8, 8]),
1330        );
1331        let right =
1332            build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]));
1333        let on = vec![
1334            (
1335                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1336                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
1337            ),
1338            (
1339                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1340                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1341            ),
1342        ];
1343
1344        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
1345        // The output order is important as SMJ preserves sortedness
1346        assert_snapshot!(batches_to_string(&batches), @r#"
1347            +----+----+
1348            | a2 | b1 |
1349            +----+----+
1350            | 10 | 4  |
1351            | 20 | 5  |
1352            | 30 | 6  |
1353            +----+----+
1354            "#);
1355
1356        let left = build_table(
1357            ("a1", &vec![1, 2, 2]),
1358            ("b1", &vec![4, 5, 5]),
1359            ("c1", &vec![7, 8, 8]),
1360        );
1361        let right = build_table(
1362            ("a2", &vec![10, 20, 30]),
1363            ("b1", &vec![4, 5, 6]),
1364            ("c2", &vec![70, 80, 90]),
1365        );
1366
1367        let on = vec![
1368            (
1369                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1370                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
1371            ),
1372            (
1373                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1374                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1375            ),
1376        ];
1377
1378        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
1379        let expected = [
1380            "+----+----+----+",
1381            "| a2 | b1 | c2 |",
1382            "+----+----+----+",
1383            "| 10 | 4  | 70 |",
1384            "| 20 | 5  | 80 |",
1385            "| 30 | 6  | 90 |",
1386            "+----+----+----+",
1387        ];
1388        // The output order is important as SMJ preserves sortedness
1389        assert_batches_eq!(expected, &batches);
1390
1391        Ok(())
1392    }
1393
1394    #[tokio::test]
1395    async fn join_right_anti_two_with_filter() -> Result<()> {
1396        let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30]));
1397        let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20]));
1398        let on = vec![
1399            (
1400                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1401                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1402            ),
1403            (
1404                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1405                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1406            ),
1407        ];
1408        let filter = JoinFilter::new(
1409            Arc::new(BinaryExpr::new(
1410                Arc::new(Column::new("c2", 1)),
1411                Operator::Gt,
1412                Arc::new(Column::new("c1", 0)),
1413            )),
1414            vec![
1415                ColumnIndex {
1416                    index: 2,
1417                    side: JoinSide::Left,
1418                },
1419                ColumnIndex {
1420                    index: 2,
1421                    side: JoinSide::Right,
1422                },
1423            ],
1424            Arc::new(Schema::new(vec![
1425                Field::new("c1", DataType::Int32, true),
1426                Field::new("c2", DataType::Int32, true),
1427            ])),
1428        );
1429        let (_, batches) =
1430            join_collect_with_filter(left, right, on, filter, RightAnti).await?;
1431        assert_snapshot!(batches_to_string(&batches), @r#"
1432            +----+----+----+
1433            | a1 | b1 | c2 |
1434            +----+----+----+
1435            | 1  | 10 | 20 |
1436            +----+----+----+
1437            "#);
1438        Ok(())
1439    }
1440
1441    #[tokio::test]
1442    async fn join_right_anti_with_nulls() -> Result<()> {
1443        let left = build_table_i32_nullable(
1444            ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
1445            ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
1446            ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
1447        );
1448        let right = build_table_i32_nullable(
1449            ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
1450            ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field
1451            ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field
1452        );
1453        let on = vec![
1454            (
1455                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1456                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1457            ),
1458            (
1459                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1460                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1461            ),
1462        ];
1463
1464        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
1465        // The output order is important as SMJ preserves sortedness
1466        assert_snapshot!(batches_to_string(&batches), @r#"
1467            +----+----+----+
1468            | a1 | b1 | c2 |
1469            +----+----+----+
1470            | 2  |    | 8  |
1471            +----+----+----+
1472            "#);
1473        Ok(())
1474    }
1475
1476    #[tokio::test]
1477    async fn join_right_anti_with_nulls_with_options() -> Result<()> {
1478        let left = build_table_i32_nullable(
1479            ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]),
1480            ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]),
1481            ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]),
1482        );
1483        let right = build_table_i32_nullable(
1484            ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
1485            ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field
1486            ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field
1487        );
1488        let on = vec![
1489            (
1490                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1491                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1492            ),
1493            (
1494                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1495                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1496            ),
1497        ];
1498
1499        let (_, batches) = join_collect_with_options(
1500            left,
1501            right,
1502            on,
1503            RightAnti,
1504            vec![
1505                SortOptions {
1506                    descending: true,
1507                    nulls_first: false,
1508                };
1509                2
1510            ],
1511            NullEquality::NullEqualsNull,
1512        )
1513        .await?;
1514
1515        // The output order is important as SMJ preserves sortedness
1516        assert_snapshot!(batches_to_string(&batches), @r#"
1517            +----+----+----+
1518            | a1 | b1 | c2 |
1519            +----+----+----+
1520            | 3  |    | 9  |
1521            | 2  | 5  |    |
1522            | 2  | 5  | 8  |
1523            +----+----+----+
1524            "#);
1525        Ok(())
1526    }
1527
1528    #[tokio::test]
1529    async fn join_right_anti_output_two_batches() -> Result<()> {
1530        let left = build_table(
1531            ("a1", &vec![1, 2, 2]),
1532            ("b1", &vec![4, 5, 5]),
1533            ("c1", &vec![7, 8, 8]),
1534        );
1535        let right = build_table(
1536            ("a2", &vec![10, 20, 30]),
1537            ("b1", &vec![4, 5, 6]),
1538            ("c2", &vec![70, 80, 90]),
1539        );
1540        let on = vec![
1541            (
1542                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1543                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
1544            ),
1545            (
1546                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1547                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1548            ),
1549        ];
1550
1551        let (_, batches) =
1552            join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?;
1553        assert_eq!(batches.len(), 2);
1554        assert_eq!(batches[0].num_rows(), 2);
1555        assert_eq!(batches[1].num_rows(), 1);
1556        assert_snapshot!(batches_to_string(&batches), @r#"
1557            +----+----+----+
1558            | a1 | b1 | c1 |
1559            +----+----+----+
1560            | 1  | 4  | 7  |
1561            | 2  | 5  | 8  |
1562            | 2  | 5  | 8  |
1563            +----+----+----+
1564            "#);
1565        Ok(())
1566    }
1567
1568    #[tokio::test]
1569    async fn join_left_semi() -> Result<()> {
1570        let left = build_table(
1571            ("a1", &vec![1, 2, 2, 3]),
1572            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
1573            ("c1", &vec![7, 8, 8, 9]),
1574        );
1575        let right = build_table(
1576            ("a2", &vec![10, 20, 30]),
1577            ("b1", &vec![4, 5, 6]), // 5 is double on the right
1578            ("c2", &vec![70, 80, 90]),
1579        );
1580        let on = vec![(
1581            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1582            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1583        )];
1584
1585        let (_, batches) = join_collect(left, right, on, LeftSemi).await?;
1586        // The output order is important as SMJ preserves sortedness
1587        assert_snapshot!(batches_to_string(&batches), @r#"
1588            +----+----+----+
1589            | a1 | b1 | c1 |
1590            +----+----+----+
1591            | 1  | 4  | 7  |
1592            | 2  | 5  | 8  |
1593            | 2  | 5  | 8  |
1594            +----+----+----+
1595            "#);
1596        Ok(())
1597    }
1598
1599    #[tokio::test]
1600    async fn join_right_semi_one() -> Result<()> {
1601        let left = build_table(
1602            ("a1", &vec![10, 20, 30, 40]),
1603            ("b1", &vec![4, 5, 5, 6]),
1604            ("c1", &vec![70, 80, 90, 100]),
1605        );
1606        let right = build_table(
1607            ("a2", &vec![1, 2, 2, 3]),
1608            ("b1", &vec![4, 5, 5, 7]),
1609            ("c2", &vec![7, 8, 8, 9]),
1610        );
1611        let on = vec![(
1612            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1613            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1614        )];
1615
1616        let (_, batches) = join_collect(left, right, on, RightSemi).await?;
1617        let expected = [
1618            "+----+----+----+",
1619            "| a2 | b1 | c2 |",
1620            "+----+----+----+",
1621            "| 1  | 4  | 7  |",
1622            "| 2  | 5  | 8  |",
1623            "| 2  | 5  | 8  |",
1624            "+----+----+----+",
1625        ];
1626        assert_batches_eq!(expected, &batches);
1627        Ok(())
1628    }
1629
1630    #[tokio::test]
1631    async fn join_right_semi_two() -> Result<()> {
1632        let left = build_table(
1633            ("a1", &vec![1, 2, 2, 3]),
1634            ("b1", &vec![4, 5, 5, 6]),
1635            ("c1", &vec![70, 80, 90, 100]),
1636        );
1637        let right = build_table(
1638            ("a1", &vec![1, 2, 2, 3]),
1639            ("b1", &vec![4, 5, 5, 7]),
1640            ("c2", &vec![7, 8, 8, 9]),
1641        );
1642        let on = vec![
1643            (
1644                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1645                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1646            ),
1647            (
1648                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1649                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1650            ),
1651        ];
1652
1653        let (_, batches) = join_collect(left, right, on, RightSemi).await?;
1654        let expected = [
1655            "+----+----+----+",
1656            "| a1 | b1 | c2 |",
1657            "+----+----+----+",
1658            "| 1  | 4  | 7  |",
1659            "| 2  | 5  | 8  |",
1660            "| 2  | 5  | 8  |",
1661            "+----+----+----+",
1662        ];
1663        assert_batches_eq!(expected, &batches);
1664        Ok(())
1665    }
1666
1667    #[tokio::test]
1668    async fn join_right_semi_two_with_filter() -> Result<()> {
1669        let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30]));
1670        let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20]));
1671        let on = vec![
1672            (
1673                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1674                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1675            ),
1676            (
1677                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1678                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1679            ),
1680        ];
1681        let filter = JoinFilter::new(
1682            Arc::new(BinaryExpr::new(
1683                Arc::new(Column::new("c2", 1)),
1684                Operator::Lt,
1685                Arc::new(Column::new("c1", 0)),
1686            )),
1687            vec![
1688                ColumnIndex {
1689                    index: 2,
1690                    side: JoinSide::Left,
1691                },
1692                ColumnIndex {
1693                    index: 2,
1694                    side: JoinSide::Right,
1695                },
1696            ],
1697            Arc::new(Schema::new(vec![
1698                Field::new("c1", DataType::Int32, true),
1699                Field::new("c2", DataType::Int32, true),
1700            ])),
1701        );
1702        let (_, batches) =
1703            join_collect_with_filter(left, right, on, filter, RightSemi).await?;
1704        let expected = [
1705            "+----+----+----+",
1706            "| a1 | b1 | c2 |",
1707            "+----+----+----+",
1708            "| 1  | 10 | 20 |",
1709            "+----+----+----+",
1710        ];
1711        assert_batches_eq!(expected, &batches);
1712        Ok(())
1713    }
1714
1715    #[tokio::test]
1716    async fn join_right_semi_with_nulls() -> Result<()> {
1717        let left = build_table_i32_nullable(
1718            ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
1719            ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
1720            ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
1721        );
1722        let right = build_table_i32_nullable(
1723            ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
1724            ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field
1725            ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field
1726        );
1727        let on = vec![
1728            (
1729                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1730                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1731            ),
1732            (
1733                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1734                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1735            ),
1736        ];
1737
1738        let (_, batches) = join_collect(left, right, on, RightSemi).await?;
1739        let expected = [
1740            "+----+----+----+",
1741            "| a1 | b1 | c2 |",
1742            "+----+----+----+",
1743            "| 1  | 4  | 7  |",
1744            "| 2  | 5  | 8  |",
1745            "| 3  | 6  |    |",
1746            "+----+----+----+",
1747        ];
1748        // The output order is important as SMJ preserves sortedness
1749        assert_batches_eq!(expected, &batches);
1750        Ok(())
1751    }
1752
1753    #[tokio::test]
1754    async fn join_right_semi_with_nulls_with_options() -> Result<()> {
1755        let left = build_table_i32_nullable(
1756            ("a1", &vec![Some(3), Some(2), Some(1), Some(0), Some(2)]),
1757            ("b1", &vec![None, Some(5), Some(4), None, Some(5)]),
1758            ("c2", &vec![Some(90), Some(80), Some(70), Some(60), None]),
1759        );
1760        let right = build_table_i32_nullable(
1761            ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
1762            ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field
1763            ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field
1764        );
1765        let on = vec![
1766            (
1767                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1768                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1769            ),
1770            (
1771                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1772                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1773            ),
1774        ];
1775
1776        let (_, batches) = join_collect_with_options(
1777            left,
1778            right,
1779            on,
1780            RightSemi,
1781            vec![
1782                SortOptions {
1783                    descending: true,
1784                    nulls_first: false,
1785                };
1786                2
1787            ],
1788            NullEquality::NullEqualsNull,
1789        )
1790        .await?;
1791
1792        let expected = [
1793            "+----+----+----+",
1794            "| a1 | b1 | c2 |",
1795            "+----+----+----+",
1796            "| 3  |    | 9  |",
1797            "| 2  | 5  |    |",
1798            "| 2  | 5  | 8  |",
1799            "| 1  | 4  | 7  |",
1800            "+----+----+----+",
1801        ];
1802        // The output order is important as SMJ preserves sortedness
1803        assert_batches_eq!(expected, &batches);
1804        Ok(())
1805    }
1806
1807    #[tokio::test]
1808    async fn join_right_semi_output_two_batches() -> Result<()> {
1809        let left = build_table(
1810            ("a1", &vec![1, 2, 2, 3]),
1811            ("b1", &vec![4, 5, 5, 6]),
1812            ("c1", &vec![70, 80, 90, 100]),
1813        );
1814        let right = build_table(
1815            ("a1", &vec![1, 2, 2, 3]),
1816            ("b1", &vec![4, 5, 5, 7]),
1817            ("c2", &vec![7, 8, 8, 9]),
1818        );
1819        let on = vec![
1820            (
1821                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
1822                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
1823            ),
1824            (
1825                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1826                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1827            ),
1828        ];
1829
1830        let (_, batches) =
1831            join_collect_batch_size_equals_two(left, right, on, RightSemi).await?;
1832        let expected = [
1833            "+----+----+----+",
1834            "| a1 | b1 | c2 |",
1835            "+----+----+----+",
1836            "| 1  | 4  | 7  |",
1837            "| 2  | 5  | 8  |",
1838            "| 2  | 5  | 8  |",
1839            "+----+----+----+",
1840        ];
1841        assert_eq!(batches.len(), 2);
1842        assert_eq!(batches[0].num_rows(), 2);
1843        assert_eq!(batches[1].num_rows(), 1);
1844        assert_batches_eq!(expected, &batches);
1845        Ok(())
1846    }
1847
1848    #[tokio::test]
1849    async fn join_left_mark() -> Result<()> {
1850        let left = build_table(
1851            ("a1", &vec![1, 2, 2, 3]),
1852            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
1853            ("c1", &vec![7, 8, 8, 9]),
1854        );
1855        let right = build_table(
1856            ("a2", &vec![10, 20, 30, 40]),
1857            ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right
1858            ("c2", &vec![60, 70, 80, 90]),
1859        );
1860        let on = vec![(
1861            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1862            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1863        )];
1864
1865        let (_, batches) = join_collect(left, right, on, LeftMark).await?;
1866        // The output order is important as SMJ preserves sortedness
1867        assert_snapshot!(batches_to_string(&batches), @r#"
1868            +----+----+----+-------+
1869            | a1 | b1 | c1 | mark  |
1870            +----+----+----+-------+
1871            | 1  | 4  | 7  | true  |
1872            | 2  | 5  | 8  | true  |
1873            | 2  | 5  | 8  | true  |
1874            | 3  | 7  | 9  | false |
1875            +----+----+----+-------+
1876            "#);
1877        Ok(())
1878    }
1879
1880    #[tokio::test]
1881    async fn join_with_duplicated_column_names() -> Result<()> {
1882        let left = build_table(
1883            ("a", &vec![1, 2, 3]),
1884            ("b", &vec![4, 5, 7]),
1885            ("c", &vec![7, 8, 9]),
1886        );
1887        let right = build_table(
1888            ("a", &vec![10, 20, 30]),
1889            ("b", &vec![1, 2, 7]),
1890            ("c", &vec![70, 80, 90]),
1891        );
1892        let on = vec![(
1893            // join on a=b so there are duplicate column names on unjoined columns
1894            Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
1895            Arc::new(Column::new_with_schema("b", &right.schema())?) as _,
1896        )];
1897
1898        let (_, batches) = join_collect(left, right, on, Inner).await?;
1899        // The output order is important as SMJ preserves sortedness
1900        assert_snapshot!(batches_to_string(&batches), @r#"
1901            +---+---+---+----+---+----+
1902            | a | b | c | a  | b | c  |
1903            +---+---+---+----+---+----+
1904            | 1 | 4 | 7 | 10 | 1 | 70 |
1905            | 2 | 5 | 8 | 20 | 2 | 80 |
1906            +---+---+---+----+---+----+
1907            "#);
1908        Ok(())
1909    }
1910
1911    #[tokio::test]
1912    async fn join_date32() -> Result<()> {
1913        let left = build_date_table(
1914            ("a1", &vec![1, 2, 3]),
1915            ("b1", &vec![19107, 19108, 19108]), // this has a repetition
1916            ("c1", &vec![7, 8, 9]),
1917        );
1918        let right = build_date_table(
1919            ("a2", &vec![10, 20, 30]),
1920            ("b1", &vec![19107, 19108, 19109]),
1921            ("c2", &vec![70, 80, 90]),
1922        );
1923
1924        let on = vec![(
1925            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1926            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1927        )];
1928
1929        let (_, batches) = join_collect(left, right, on, Inner).await?;
1930
1931        // The output order is important as SMJ preserves sortedness
1932        assert_snapshot!(batches_to_string(&batches), @r#"
1933            +------------+------------+------------+------------+------------+------------+
1934            | a1         | b1         | c1         | a2         | b1         | c2         |
1935            +------------+------------+------------+------------+------------+------------+
1936            | 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |
1937            | 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |
1938            | 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |
1939            +------------+------------+------------+------------+------------+------------+
1940            "#);
1941        Ok(())
1942    }
1943
1944    #[tokio::test]
1945    async fn join_date64() -> Result<()> {
1946        let left = build_date64_table(
1947            ("a1", &vec![1, 2, 3]),
1948            ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition
1949            ("c1", &vec![7, 8, 9]),
1950        );
1951        let right = build_date64_table(
1952            ("a2", &vec![10, 20, 30]),
1953            ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
1954            ("c2", &vec![70, 80, 90]),
1955        );
1956
1957        let on = vec![(
1958            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1959            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1960        )];
1961
1962        let (_, batches) = join_collect(left, right, on, Inner).await?;
1963
1964        // The output order is important as SMJ preserves sortedness
1965        assert_snapshot!(batches_to_string(&batches), @r#"
1966            +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1967            | a1                      | b1                  | c1                      | a2                      | b1                  | c2                      |
1968            +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1969            | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |
1970            | 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |
1971            | 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |
1972            +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1973            "#);
1974        Ok(())
1975    }
1976
1977    #[tokio::test]
1978    async fn join_binary() -> Result<()> {
1979        let left = build_binary_table(
1980            (
1981                "a1",
1982                &vec![
1983                    &[0xc0, 0xff, 0xee],
1984                    &[0xde, 0xca, 0xde],
1985                    &[0xfa, 0xca, 0xde],
1986                ],
1987            ),
1988            ("b1", &vec![5, 10, 15]), // this has a repetition
1989            ("c1", &vec![7, 8, 9]),
1990        );
1991        let right = build_binary_table(
1992            (
1993                "a1",
1994                &vec![
1995                    &[0xc0, 0xff, 0xee],
1996                    &[0xde, 0xca, 0xde],
1997                    &[0xfa, 0xca, 0xde],
1998                ],
1999            ),
2000            ("b2", &vec![105, 110, 115]),
2001            ("c2", &vec![70, 80, 90]),
2002        );
2003
2004        let on = vec![(
2005            Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2006            Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2007        )];
2008
2009        let (_, batches) = join_collect(left, right, on, Inner).await?;
2010
2011        // The output order is important as SMJ preserves sortedness
2012        assert_snapshot!(batches_to_string(&batches), @r#"
2013            +--------+----+----+--------+-----+----+
2014            | a1     | b1 | c1 | a1     | b2  | c2 |
2015            +--------+----+----+--------+-----+----+
2016            | c0ffee | 5  | 7  | c0ffee | 105 | 70 |
2017            | decade | 10 | 8  | decade | 110 | 80 |
2018            | facade | 15 | 9  | facade | 115 | 90 |
2019            +--------+----+----+--------+-----+----+
2020            "#);
2021        Ok(())
2022    }
2023
2024    #[tokio::test]
2025    async fn join_fixed_size_binary() -> Result<()> {
2026        let left = build_fixed_size_binary_table(
2027            (
2028                "a1",
2029                &vec![
2030                    &[0xc0, 0xff, 0xee],
2031                    &[0xde, 0xca, 0xde],
2032                    &[0xfa, 0xca, 0xde],
2033                ],
2034            ),
2035            ("b1", &vec![5, 10, 15]), // this has a repetition
2036            ("c1", &vec![7, 8, 9]),
2037        );
2038        let right = build_fixed_size_binary_table(
2039            (
2040                "a1",
2041                &vec![
2042                    &[0xc0, 0xff, 0xee],
2043                    &[0xde, 0xca, 0xde],
2044                    &[0xfa, 0xca, 0xde],
2045                ],
2046            ),
2047            ("b2", &vec![105, 110, 115]),
2048            ("c2", &vec![70, 80, 90]),
2049        );
2050
2051        let on = vec![(
2052            Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
2053            Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
2054        )];
2055
2056        let (_, batches) = join_collect(left, right, on, Inner).await?;
2057
2058        // The output order is important as SMJ preserves sortedness
2059        assert_snapshot!(batches_to_string(&batches), @r#"
2060            +--------+----+----+--------+-----+----+
2061            | a1     | b1 | c1 | a1     | b2  | c2 |
2062            +--------+----+----+--------+-----+----+
2063            | c0ffee | 5  | 7  | c0ffee | 105 | 70 |
2064            | decade | 10 | 8  | decade | 110 | 80 |
2065            | facade | 15 | 9  | facade | 115 | 90 |
2066            +--------+----+----+--------+-----+----+
2067            "#);
2068        Ok(())
2069    }
2070
2071    #[tokio::test]
2072    async fn join_left_sort_order() -> Result<()> {
2073        let left = build_table(
2074            ("a1", &vec![0, 1, 2, 3, 4, 5]),
2075            ("b1", &vec![3, 4, 5, 6, 6, 7]),
2076            ("c1", &vec![4, 5, 6, 7, 8, 9]),
2077        );
2078        let right = build_table(
2079            ("a2", &vec![0, 10, 20, 30, 40]),
2080            ("b2", &vec![2, 4, 6, 6, 8]),
2081            ("c2", &vec![50, 60, 70, 80, 90]),
2082        );
2083        let on = vec![(
2084            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2085            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2086        )];
2087
2088        let (_, batches) = join_collect(left, right, on, Left).await?;
2089        assert_snapshot!(batches_to_string(&batches), @r#"
2090            +----+----+----+----+----+----+
2091            | a1 | b1 | c1 | a2 | b2 | c2 |
2092            +----+----+----+----+----+----+
2093            | 0  | 3  | 4  |    |    |    |
2094            | 1  | 4  | 5  | 10 | 4  | 60 |
2095            | 2  | 5  | 6  |    |    |    |
2096            | 3  | 6  | 7  | 20 | 6  | 70 |
2097            | 3  | 6  | 7  | 30 | 6  | 80 |
2098            | 4  | 6  | 8  | 20 | 6  | 70 |
2099            | 4  | 6  | 8  | 30 | 6  | 80 |
2100            | 5  | 7  | 9  |    |    |    |
2101            +----+----+----+----+----+----+
2102            "#);
2103        Ok(())
2104    }
2105
2106    #[tokio::test]
2107    async fn join_right_sort_order() -> Result<()> {
2108        let left = build_table(
2109            ("a1", &vec![0, 1, 2, 3]),
2110            ("b1", &vec![3, 4, 5, 7]),
2111            ("c1", &vec![6, 7, 8, 9]),
2112        );
2113        let right = build_table(
2114            ("a2", &vec![0, 10, 20, 30]),
2115            ("b2", &vec![2, 4, 5, 6]),
2116            ("c2", &vec![60, 70, 80, 90]),
2117        );
2118        let on = vec![(
2119            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2120            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2121        )];
2122
2123        let (_, batches) = join_collect(left, right, on, Right).await?;
2124        assert_snapshot!(batches_to_string(&batches), @r#"
2125            +----+----+----+----+----+----+
2126            | a1 | b1 | c1 | a2 | b2 | c2 |
2127            +----+----+----+----+----+----+
2128            |    |    |    | 0  | 2  | 60 |
2129            | 1  | 4  | 7  | 10 | 4  | 70 |
2130            | 2  | 5  | 8  | 20 | 5  | 80 |
2131            |    |    |    | 30 | 6  | 90 |
2132            +----+----+----+----+----+----+
2133            "#);
2134        Ok(())
2135    }
2136
2137    #[tokio::test]
2138    async fn join_left_multiple_batches() -> Result<()> {
2139        let left_batch_1 = build_table_i32(
2140            ("a1", &vec![0, 1, 2]),
2141            ("b1", &vec![3, 4, 5]),
2142            ("c1", &vec![4, 5, 6]),
2143        );
2144        let left_batch_2 = build_table_i32(
2145            ("a1", &vec![3, 4, 5, 6]),
2146            ("b1", &vec![6, 6, 7, 9]),
2147            ("c1", &vec![7, 8, 9, 9]),
2148        );
2149        let right_batch_1 = build_table_i32(
2150            ("a2", &vec![0, 10, 20]),
2151            ("b2", &vec![2, 4, 6]),
2152            ("c2", &vec![50, 60, 70]),
2153        );
2154        let right_batch_2 = build_table_i32(
2155            ("a2", &vec![30, 40]),
2156            ("b2", &vec![6, 8]),
2157            ("c2", &vec![80, 90]),
2158        );
2159        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
2160        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
2161        let on = vec![(
2162            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2163            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2164        )];
2165
2166        let (_, batches) = join_collect(left, right, on, Left).await?;
2167        assert_snapshot!(batches_to_string(&batches), @r#"
2168            +----+----+----+----+----+----+
2169            | a1 | b1 | c1 | a2 | b2 | c2 |
2170            +----+----+----+----+----+----+
2171            | 0  | 3  | 4  |    |    |    |
2172            | 1  | 4  | 5  | 10 | 4  | 60 |
2173            | 2  | 5  | 6  |    |    |    |
2174            | 3  | 6  | 7  | 20 | 6  | 70 |
2175            | 3  | 6  | 7  | 30 | 6  | 80 |
2176            | 4  | 6  | 8  | 20 | 6  | 70 |
2177            | 4  | 6  | 8  | 30 | 6  | 80 |
2178            | 5  | 7  | 9  |    |    |    |
2179            | 6  | 9  | 9  |    |    |    |
2180            +----+----+----+----+----+----+
2181            "#);
2182        Ok(())
2183    }
2184
2185    #[tokio::test]
2186    async fn join_right_multiple_batches() -> Result<()> {
2187        let right_batch_1 = build_table_i32(
2188            ("a2", &vec![0, 1, 2]),
2189            ("b2", &vec![3, 4, 5]),
2190            ("c2", &vec![4, 5, 6]),
2191        );
2192        let right_batch_2 = build_table_i32(
2193            ("a2", &vec![3, 4, 5, 6]),
2194            ("b2", &vec![6, 6, 7, 9]),
2195            ("c2", &vec![7, 8, 9, 9]),
2196        );
2197        let left_batch_1 = build_table_i32(
2198            ("a1", &vec![0, 10, 20]),
2199            ("b1", &vec![2, 4, 6]),
2200            ("c1", &vec![50, 60, 70]),
2201        );
2202        let left_batch_2 = build_table_i32(
2203            ("a1", &vec![30, 40]),
2204            ("b1", &vec![6, 8]),
2205            ("c1", &vec![80, 90]),
2206        );
2207        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
2208        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
2209        let on = vec![(
2210            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2211            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2212        )];
2213
2214        let (_, batches) = join_collect(left, right, on, Right).await?;
2215        assert_snapshot!(batches_to_string(&batches), @r#"
2216            +----+----+----+----+----+----+
2217            | a1 | b1 | c1 | a2 | b2 | c2 |
2218            +----+----+----+----+----+----+
2219            |    |    |    | 0  | 3  | 4  |
2220            | 10 | 4  | 60 | 1  | 4  | 5  |
2221            |    |    |    | 2  | 5  | 6  |
2222            | 20 | 6  | 70 | 3  | 6  | 7  |
2223            | 30 | 6  | 80 | 3  | 6  | 7  |
2224            | 20 | 6  | 70 | 4  | 6  | 8  |
2225            | 30 | 6  | 80 | 4  | 6  | 8  |
2226            |    |    |    | 5  | 7  | 9  |
2227            |    |    |    | 6  | 9  | 9  |
2228            +----+----+----+----+----+----+
2229            "#);
2230        Ok(())
2231    }
2232
2233    #[tokio::test]
2234    async fn join_full_multiple_batches() -> Result<()> {
2235        let left_batch_1 = build_table_i32(
2236            ("a1", &vec![0, 1, 2]),
2237            ("b1", &vec![3, 4, 5]),
2238            ("c1", &vec![4, 5, 6]),
2239        );
2240        let left_batch_2 = build_table_i32(
2241            ("a1", &vec![3, 4, 5, 6]),
2242            ("b1", &vec![6, 6, 7, 9]),
2243            ("c1", &vec![7, 8, 9, 9]),
2244        );
2245        let right_batch_1 = build_table_i32(
2246            ("a2", &vec![0, 10, 20]),
2247            ("b2", &vec![2, 4, 6]),
2248            ("c2", &vec![50, 60, 70]),
2249        );
2250        let right_batch_2 = build_table_i32(
2251            ("a2", &vec![30, 40]),
2252            ("b2", &vec![6, 8]),
2253            ("c2", &vec![80, 90]),
2254        );
2255        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
2256        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
2257        let on = vec![(
2258            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2259            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2260        )];
2261
2262        let (_, batches) = join_collect(left, right, on, Full).await?;
2263        assert_snapshot!(batches_to_sort_string(&batches), @r#"
2264            +----+----+----+----+----+----+
2265            | a1 | b1 | c1 | a2 | b2 | c2 |
2266            +----+----+----+----+----+----+
2267            |    |    |    | 0  | 2  | 50 |
2268            |    |    |    | 40 | 8  | 90 |
2269            | 0  | 3  | 4  |    |    |    |
2270            | 1  | 4  | 5  | 10 | 4  | 60 |
2271            | 2  | 5  | 6  |    |    |    |
2272            | 3  | 6  | 7  | 20 | 6  | 70 |
2273            | 3  | 6  | 7  | 30 | 6  | 80 |
2274            | 4  | 6  | 8  | 20 | 6  | 70 |
2275            | 4  | 6  | 8  | 30 | 6  | 80 |
2276            | 5  | 7  | 9  |    |    |    |
2277            | 6  | 9  | 9  |    |    |    |
2278            +----+----+----+----+----+----+
2279            "#);
2280        Ok(())
2281    }
2282
2283    #[tokio::test]
2284    async fn overallocation_single_batch_no_spill() -> Result<()> {
2285        let left = build_table(
2286            ("a1", &vec![0, 1, 2, 3, 4, 5]),
2287            ("b1", &vec![1, 2, 3, 4, 5, 6]),
2288            ("c1", &vec![4, 5, 6, 7, 8, 9]),
2289        );
2290        let right = build_table(
2291            ("a2", &vec![0, 10, 20, 30, 40]),
2292            ("b2", &vec![1, 3, 4, 6, 8]),
2293            ("c2", &vec![50, 60, 70, 80, 90]),
2294        );
2295        let on = vec![(
2296            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2297            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2298        )];
2299        let sort_options = vec![SortOptions::default(); on.len()];
2300
2301        let join_types = vec![
2302            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark,
2303        ];
2304
2305        // Disable DiskManager to prevent spilling
2306        let runtime = RuntimeEnvBuilder::new()
2307            .with_memory_limit(100, 1.0)
2308            .with_disk_manager_builder(
2309                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2310            )
2311            .build_arc()?;
2312        let session_config = SessionConfig::default().with_batch_size(50);
2313
2314        for join_type in join_types {
2315            let task_ctx = TaskContext::default()
2316                .with_session_config(session_config.clone())
2317                .with_runtime(Arc::clone(&runtime));
2318            let task_ctx = Arc::new(task_ctx);
2319
2320            let join = join_with_options(
2321                Arc::clone(&left),
2322                Arc::clone(&right),
2323                on.clone(),
2324                join_type,
2325                sort_options.clone(),
2326                NullEquality::NullEqualsNothing,
2327            )?;
2328
2329            let stream = join.execute(0, task_ctx)?;
2330            let err = common::collect(stream).await.unwrap_err();
2331
2332            assert_contains!(err.to_string(), "Failed to allocate additional");
2333            assert_contains!(err.to_string(), "SMJStream[0]");
2334            assert_contains!(err.to_string(), "Disk spilling disabled");
2335            assert!(join.metrics().is_some());
2336            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
2337            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
2338            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
2339        }
2340
2341        Ok(())
2342    }
2343
2344    #[tokio::test]
2345    async fn overallocation_multi_batch_no_spill() -> Result<()> {
2346        let left_batch_1 = build_table_i32(
2347            ("a1", &vec![0, 1]),
2348            ("b1", &vec![1, 1]),
2349            ("c1", &vec![4, 5]),
2350        );
2351        let left_batch_2 = build_table_i32(
2352            ("a1", &vec![2, 3]),
2353            ("b1", &vec![1, 1]),
2354            ("c1", &vec![6, 7]),
2355        );
2356        let left_batch_3 = build_table_i32(
2357            ("a1", &vec![4, 5]),
2358            ("b1", &vec![1, 1]),
2359            ("c1", &vec![8, 9]),
2360        );
2361        let right_batch_1 = build_table_i32(
2362            ("a2", &vec![0, 10]),
2363            ("b2", &vec![1, 1]),
2364            ("c2", &vec![50, 60]),
2365        );
2366        let right_batch_2 = build_table_i32(
2367            ("a2", &vec![20, 30]),
2368            ("b2", &vec![1, 1]),
2369            ("c2", &vec![70, 80]),
2370        );
2371        let right_batch_3 =
2372            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
2373        let left =
2374            build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
2375        let right =
2376            build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
2377        let on = vec![(
2378            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2379            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2380        )];
2381        let sort_options = vec![SortOptions::default(); on.len()];
2382
2383        let join_types = vec![
2384            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark,
2385        ];
2386
2387        // Disable DiskManager to prevent spilling
2388        let runtime = RuntimeEnvBuilder::new()
2389            .with_memory_limit(100, 1.0)
2390            .with_disk_manager_builder(
2391                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2392            )
2393            .build_arc()?;
2394        let session_config = SessionConfig::default().with_batch_size(50);
2395
2396        for join_type in join_types {
2397            let task_ctx = TaskContext::default()
2398                .with_session_config(session_config.clone())
2399                .with_runtime(Arc::clone(&runtime));
2400            let task_ctx = Arc::new(task_ctx);
2401            let join = join_with_options(
2402                Arc::clone(&left),
2403                Arc::clone(&right),
2404                on.clone(),
2405                join_type,
2406                sort_options.clone(),
2407                NullEquality::NullEqualsNothing,
2408            )?;
2409
2410            let stream = join.execute(0, task_ctx)?;
2411            let err = common::collect(stream).await.unwrap_err();
2412
2413            assert_contains!(err.to_string(), "Failed to allocate additional");
2414            assert_contains!(err.to_string(), "SMJStream[0]");
2415            assert_contains!(err.to_string(), "Disk spilling disabled");
2416            assert!(join.metrics().is_some());
2417            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
2418            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
2419            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
2420        }
2421
2422        Ok(())
2423    }
2424
2425    #[tokio::test]
2426    async fn overallocation_single_batch_spill() -> Result<()> {
2427        let left = build_table(
2428            ("a1", &vec![0, 1, 2, 3, 4, 5]),
2429            ("b1", &vec![1, 2, 3, 4, 5, 6]),
2430            ("c1", &vec![4, 5, 6, 7, 8, 9]),
2431        );
2432        let right = build_table(
2433            ("a2", &vec![0, 10, 20, 30, 40]),
2434            ("b2", &vec![1, 3, 4, 6, 8]),
2435            ("c2", &vec![50, 60, 70, 80, 90]),
2436        );
2437        let on = vec![(
2438            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2439            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2440        )];
2441        let sort_options = vec![SortOptions::default(); on.len()];
2442
2443        let join_types = [
2444            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark,
2445        ];
2446
2447        // Enable DiskManager to allow spilling
2448        let runtime = RuntimeEnvBuilder::new()
2449            .with_memory_limit(100, 1.0)
2450            .with_disk_manager_builder(
2451                DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
2452            )
2453            .build_arc()?;
2454
2455        for batch_size in [1, 50] {
2456            let session_config = SessionConfig::default().with_batch_size(batch_size);
2457
2458            for join_type in &join_types {
2459                let task_ctx = TaskContext::default()
2460                    .with_session_config(session_config.clone())
2461                    .with_runtime(Arc::clone(&runtime));
2462                let task_ctx = Arc::new(task_ctx);
2463
2464                let join = join_with_options(
2465                    Arc::clone(&left),
2466                    Arc::clone(&right),
2467                    on.clone(),
2468                    *join_type,
2469                    sort_options.clone(),
2470                    NullEquality::NullEqualsNothing,
2471                )?;
2472
2473                let stream = join.execute(0, task_ctx)?;
2474                let spilled_join_result = common::collect(stream).await.unwrap();
2475
2476                assert!(join.metrics().is_some());
2477                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
2478                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
2479                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
2480
2481                // Run the test with no spill configuration as
2482                let task_ctx_no_spill =
2483                    TaskContext::default().with_session_config(session_config.clone());
2484                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
2485
2486                let join = join_with_options(
2487                    Arc::clone(&left),
2488                    Arc::clone(&right),
2489                    on.clone(),
2490                    *join_type,
2491                    sort_options.clone(),
2492                    NullEquality::NullEqualsNothing,
2493                )?;
2494                let stream = join.execute(0, task_ctx_no_spill)?;
2495                let no_spilled_join_result = common::collect(stream).await.unwrap();
2496
2497                assert!(join.metrics().is_some());
2498                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
2499                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
2500                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
2501                // Compare spilled and non spilled data to check spill logic doesn't corrupt the data
2502                assert_eq!(spilled_join_result, no_spilled_join_result);
2503            }
2504        }
2505
2506        Ok(())
2507    }
2508
2509    #[tokio::test]
2510    async fn overallocation_multi_batch_spill() -> Result<()> {
2511        let left_batch_1 = build_table_i32(
2512            ("a1", &vec![0, 1]),
2513            ("b1", &vec![1, 1]),
2514            ("c1", &vec![4, 5]),
2515        );
2516        let left_batch_2 = build_table_i32(
2517            ("a1", &vec![2, 3]),
2518            ("b1", &vec![1, 1]),
2519            ("c1", &vec![6, 7]),
2520        );
2521        let left_batch_3 = build_table_i32(
2522            ("a1", &vec![4, 5]),
2523            ("b1", &vec![1, 1]),
2524            ("c1", &vec![8, 9]),
2525        );
2526        let right_batch_1 = build_table_i32(
2527            ("a2", &vec![0, 10]),
2528            ("b2", &vec![1, 1]),
2529            ("c2", &vec![50, 60]),
2530        );
2531        let right_batch_2 = build_table_i32(
2532            ("a2", &vec![20, 30]),
2533            ("b2", &vec![1, 1]),
2534            ("c2", &vec![70, 80]),
2535        );
2536        let right_batch_3 =
2537            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
2538        let left =
2539            build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
2540        let right =
2541            build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
2542        let on = vec![(
2543            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2544            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2545        )];
2546        let sort_options = vec![SortOptions::default(); on.len()];
2547
2548        let join_types = [
2549            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark,
2550        ];
2551
2552        // Enable DiskManager to allow spilling
2553        let runtime = RuntimeEnvBuilder::new()
2554            .with_memory_limit(500, 1.0)
2555            .with_disk_manager_builder(
2556                DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
2557            )
2558            .build_arc()?;
2559
2560        for batch_size in [1, 50] {
2561            let session_config = SessionConfig::default().with_batch_size(batch_size);
2562
2563            for join_type in &join_types {
2564                let task_ctx = TaskContext::default()
2565                    .with_session_config(session_config.clone())
2566                    .with_runtime(Arc::clone(&runtime));
2567                let task_ctx = Arc::new(task_ctx);
2568                let join = join_with_options(
2569                    Arc::clone(&left),
2570                    Arc::clone(&right),
2571                    on.clone(),
2572                    *join_type,
2573                    sort_options.clone(),
2574                    NullEquality::NullEqualsNothing,
2575                )?;
2576
2577                let stream = join.execute(0, task_ctx)?;
2578                let spilled_join_result = common::collect(stream).await.unwrap();
2579                assert!(join.metrics().is_some());
2580                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
2581                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
2582                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
2583
2584                // Run the test with no spill configuration as
2585                let task_ctx_no_spill =
2586                    TaskContext::default().with_session_config(session_config.clone());
2587                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
2588
2589                let join = join_with_options(
2590                    Arc::clone(&left),
2591                    Arc::clone(&right),
2592                    on.clone(),
2593                    *join_type,
2594                    sort_options.clone(),
2595                    NullEquality::NullEqualsNothing,
2596                )?;
2597                let stream = join.execute(0, task_ctx_no_spill)?;
2598                let no_spilled_join_result = common::collect(stream).await.unwrap();
2599
2600                assert!(join.metrics().is_some());
2601                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
2602                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
2603                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
2604                // Compare spilled and non spilled data to check spill logic doesn't corrupt the data
2605                assert_eq!(spilled_join_result, no_spilled_join_result);
2606            }
2607        }
2608
2609        Ok(())
2610    }
2611
2612    fn build_joined_record_batches() -> Result<JoinedRecordBatches> {
2613        let schema = Arc::new(Schema::new(vec![
2614            Field::new("a", DataType::Int32, true),
2615            Field::new("b", DataType::Int32, true),
2616            Field::new("x", DataType::Int32, true),
2617            Field::new("y", DataType::Int32, true),
2618        ]));
2619
2620        let mut batches = JoinedRecordBatches {
2621            batches: vec![],
2622            filter_mask: BooleanBuilder::new(),
2623            row_indices: UInt64Builder::new(),
2624            batch_ids: vec![],
2625        };
2626
2627        // Insert already prejoined non-filtered rows
2628        batches.batches.push(RecordBatch::try_new(
2629            Arc::clone(&schema),
2630            vec![
2631                Arc::new(Int32Array::from(vec![1, 1])),
2632                Arc::new(Int32Array::from(vec![10, 10])),
2633                Arc::new(Int32Array::from(vec![1, 1])),
2634                Arc::new(Int32Array::from(vec![11, 9])),
2635            ],
2636        )?);
2637
2638        batches.batches.push(RecordBatch::try_new(
2639            Arc::clone(&schema),
2640            vec![
2641                Arc::new(Int32Array::from(vec![1])),
2642                Arc::new(Int32Array::from(vec![11])),
2643                Arc::new(Int32Array::from(vec![1])),
2644                Arc::new(Int32Array::from(vec![12])),
2645            ],
2646        )?);
2647
2648        batches.batches.push(RecordBatch::try_new(
2649            Arc::clone(&schema),
2650            vec![
2651                Arc::new(Int32Array::from(vec![1, 1])),
2652                Arc::new(Int32Array::from(vec![12, 12])),
2653                Arc::new(Int32Array::from(vec![1, 1])),
2654                Arc::new(Int32Array::from(vec![11, 13])),
2655            ],
2656        )?);
2657
2658        batches.batches.push(RecordBatch::try_new(
2659            Arc::clone(&schema),
2660            vec![
2661                Arc::new(Int32Array::from(vec![1])),
2662                Arc::new(Int32Array::from(vec![13])),
2663                Arc::new(Int32Array::from(vec![1])),
2664                Arc::new(Int32Array::from(vec![12])),
2665            ],
2666        )?);
2667
2668        batches.batches.push(RecordBatch::try_new(
2669            Arc::clone(&schema),
2670            vec![
2671                Arc::new(Int32Array::from(vec![1, 1])),
2672                Arc::new(Int32Array::from(vec![14, 14])),
2673                Arc::new(Int32Array::from(vec![1, 1])),
2674                Arc::new(Int32Array::from(vec![12, 11])),
2675            ],
2676        )?);
2677
2678        let streamed_indices = vec![0, 0];
2679        batches.batch_ids.extend(vec![0; streamed_indices.len()]);
2680        batches
2681            .row_indices
2682            .extend(&UInt64Array::from(streamed_indices));
2683
2684        let streamed_indices = vec![1];
2685        batches.batch_ids.extend(vec![0; streamed_indices.len()]);
2686        batches
2687            .row_indices
2688            .extend(&UInt64Array::from(streamed_indices));
2689
2690        let streamed_indices = vec![0, 0];
2691        batches.batch_ids.extend(vec![1; streamed_indices.len()]);
2692        batches
2693            .row_indices
2694            .extend(&UInt64Array::from(streamed_indices));
2695
2696        let streamed_indices = vec![0];
2697        batches.batch_ids.extend(vec![2; streamed_indices.len()]);
2698        batches
2699            .row_indices
2700            .extend(&UInt64Array::from(streamed_indices));
2701
2702        let streamed_indices = vec![0, 0];
2703        batches.batch_ids.extend(vec![3; streamed_indices.len()]);
2704        batches
2705            .row_indices
2706            .extend(&UInt64Array::from(streamed_indices));
2707
2708        batches
2709            .filter_mask
2710            .extend(&BooleanArray::from(vec![true, false]));
2711        batches.filter_mask.extend(&BooleanArray::from(vec![true]));
2712        batches
2713            .filter_mask
2714            .extend(&BooleanArray::from(vec![false, true]));
2715        batches.filter_mask.extend(&BooleanArray::from(vec![false]));
2716        batches
2717            .filter_mask
2718            .extend(&BooleanArray::from(vec![false, false]));
2719
2720        Ok(batches)
2721    }
2722
2723    #[tokio::test]
2724    async fn test_left_outer_join_filtered_mask() -> Result<()> {
2725        let mut joined_batches = build_joined_record_batches()?;
2726        let schema = joined_batches.batches.first().unwrap().schema();
2727
2728        let output = concat_batches(&schema, &joined_batches.batches)?;
2729        let out_mask = joined_batches.filter_mask.finish();
2730        let out_indices = joined_batches.row_indices.finish();
2731
2732        assert_eq!(
2733            get_corrected_filter_mask(
2734                Left,
2735                &UInt64Array::from(vec![0]),
2736                &[0usize],
2737                &BooleanArray::from(vec![true]),
2738                output.num_rows()
2739            )
2740            .unwrap(),
2741            BooleanArray::from(vec![
2742                true, false, false, false, false, false, false, false
2743            ])
2744        );
2745
2746        assert_eq!(
2747            get_corrected_filter_mask(
2748                Left,
2749                &UInt64Array::from(vec![0]),
2750                &[0usize],
2751                &BooleanArray::from(vec![false]),
2752                output.num_rows()
2753            )
2754            .unwrap(),
2755            BooleanArray::from(vec![
2756                false, false, false, false, false, false, false, false
2757            ])
2758        );
2759
2760        assert_eq!(
2761            get_corrected_filter_mask(
2762                Left,
2763                &UInt64Array::from(vec![0, 0]),
2764                &[0usize; 2],
2765                &BooleanArray::from(vec![true, true]),
2766                output.num_rows()
2767            )
2768            .unwrap(),
2769            BooleanArray::from(vec![
2770                true, true, false, false, false, false, false, false
2771            ])
2772        );
2773
2774        assert_eq!(
2775            get_corrected_filter_mask(
2776                Left,
2777                &UInt64Array::from(vec![0, 0, 0]),
2778                &[0usize; 3],
2779                &BooleanArray::from(vec![true, true, true]),
2780                output.num_rows()
2781            )
2782            .unwrap(),
2783            BooleanArray::from(vec![true, true, true, false, false, false, false, false])
2784        );
2785
2786        assert_eq!(
2787            get_corrected_filter_mask(
2788                Left,
2789                &UInt64Array::from(vec![0, 0, 0]),
2790                &[0usize; 3],
2791                &BooleanArray::from(vec![true, false, true]),
2792                output.num_rows()
2793            )
2794            .unwrap(),
2795            BooleanArray::from(vec![
2796                Some(true),
2797                None,
2798                Some(true),
2799                Some(false),
2800                Some(false),
2801                Some(false),
2802                Some(false),
2803                Some(false)
2804            ])
2805        );
2806
2807        assert_eq!(
2808            get_corrected_filter_mask(
2809                Left,
2810                &UInt64Array::from(vec![0, 0, 0]),
2811                &[0usize; 3],
2812                &BooleanArray::from(vec![false, false, true]),
2813                output.num_rows()
2814            )
2815            .unwrap(),
2816            BooleanArray::from(vec![
2817                None,
2818                None,
2819                Some(true),
2820                Some(false),
2821                Some(false),
2822                Some(false),
2823                Some(false),
2824                Some(false)
2825            ])
2826        );
2827
2828        assert_eq!(
2829            get_corrected_filter_mask(
2830                Left,
2831                &UInt64Array::from(vec![0, 0, 0]),
2832                &[0usize; 3],
2833                &BooleanArray::from(vec![false, true, true]),
2834                output.num_rows()
2835            )
2836            .unwrap(),
2837            BooleanArray::from(vec![
2838                None,
2839                Some(true),
2840                Some(true),
2841                Some(false),
2842                Some(false),
2843                Some(false),
2844                Some(false),
2845                Some(false)
2846            ])
2847        );
2848
2849        assert_eq!(
2850            get_corrected_filter_mask(
2851                Left,
2852                &UInt64Array::from(vec![0, 0, 0]),
2853                &[0usize; 3],
2854                &BooleanArray::from(vec![false, false, false]),
2855                output.num_rows()
2856            )
2857            .unwrap(),
2858            BooleanArray::from(vec![
2859                None,
2860                None,
2861                Some(false),
2862                Some(false),
2863                Some(false),
2864                Some(false),
2865                Some(false),
2866                Some(false)
2867            ])
2868        );
2869
2870        let corrected_mask = get_corrected_filter_mask(
2871            Left,
2872            &out_indices,
2873            &joined_batches.batch_ids,
2874            &out_mask,
2875            output.num_rows(),
2876        )
2877        .unwrap();
2878
2879        assert_eq!(
2880            corrected_mask,
2881            BooleanArray::from(vec![
2882                Some(true),
2883                None,
2884                Some(true),
2885                None,
2886                Some(true),
2887                Some(false),
2888                None,
2889                Some(false)
2890            ])
2891        );
2892
2893        let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
2894
2895        assert_snapshot!(batches_to_string(&[filtered_rb]), @r#"
2896                +---+----+---+----+
2897                | a | b  | x | y  |
2898                +---+----+---+----+
2899                | 1 | 10 | 1 | 11 |
2900                | 1 | 11 | 1 | 12 |
2901                | 1 | 12 | 1 | 13 |
2902                +---+----+---+----+
2903            "#);
2904
2905        // output null rows
2906
2907        let null_mask = arrow::compute::not(&corrected_mask)?;
2908        assert_eq!(
2909            null_mask,
2910            BooleanArray::from(vec![
2911                Some(false),
2912                None,
2913                Some(false),
2914                None,
2915                Some(false),
2916                Some(true),
2917                None,
2918                Some(true)
2919            ])
2920        );
2921
2922        let null_joined_batch = filter_record_batch(&output, &null_mask)?;
2923
2924        assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#"
2925                +---+----+---+----+
2926                | a | b  | x | y  |
2927                +---+----+---+----+
2928                | 1 | 13 | 1 | 12 |
2929                | 1 | 14 | 1 | 11 |
2930                +---+----+---+----+
2931            "#);
2932        Ok(())
2933    }
2934
2935    #[tokio::test]
2936    async fn test_semi_join_filtered_mask() -> Result<()> {
2937        for join_type in [LeftSemi, RightSemi] {
2938            let mut joined_batches = build_joined_record_batches()?;
2939            let schema = joined_batches.batches.first().unwrap().schema();
2940
2941            let output = concat_batches(&schema, &joined_batches.batches)?;
2942            let out_mask = joined_batches.filter_mask.finish();
2943            let out_indices = joined_batches.row_indices.finish();
2944
2945            assert_eq!(
2946                get_corrected_filter_mask(
2947                    join_type,
2948                    &UInt64Array::from(vec![0]),
2949                    &[0usize],
2950                    &BooleanArray::from(vec![true]),
2951                    output.num_rows()
2952                )
2953                .unwrap(),
2954                BooleanArray::from(vec![true])
2955            );
2956
2957            assert_eq!(
2958                get_corrected_filter_mask(
2959                    join_type,
2960                    &UInt64Array::from(vec![0]),
2961                    &[0usize],
2962                    &BooleanArray::from(vec![false]),
2963                    output.num_rows()
2964                )
2965                .unwrap(),
2966                BooleanArray::from(vec![None])
2967            );
2968
2969            assert_eq!(
2970                get_corrected_filter_mask(
2971                    join_type,
2972                    &UInt64Array::from(vec![0, 0]),
2973                    &[0usize; 2],
2974                    &BooleanArray::from(vec![true, true]),
2975                    output.num_rows()
2976                )
2977                .unwrap(),
2978                BooleanArray::from(vec![Some(true), None])
2979            );
2980
2981            assert_eq!(
2982                get_corrected_filter_mask(
2983                    join_type,
2984                    &UInt64Array::from(vec![0, 0, 0]),
2985                    &[0usize; 3],
2986                    &BooleanArray::from(vec![true, true, true]),
2987                    output.num_rows()
2988                )
2989                .unwrap(),
2990                BooleanArray::from(vec![Some(true), None, None])
2991            );
2992
2993            assert_eq!(
2994                get_corrected_filter_mask(
2995                    join_type,
2996                    &UInt64Array::from(vec![0, 0, 0]),
2997                    &[0usize; 3],
2998                    &BooleanArray::from(vec![true, false, true]),
2999                    output.num_rows()
3000                )
3001                .unwrap(),
3002                BooleanArray::from(vec![Some(true), None, None])
3003            );
3004
3005            assert_eq!(
3006                get_corrected_filter_mask(
3007                    join_type,
3008                    &UInt64Array::from(vec![0, 0, 0]),
3009                    &[0usize; 3],
3010                    &BooleanArray::from(vec![false, false, true]),
3011                    output.num_rows()
3012                )
3013                .unwrap(),
3014                BooleanArray::from(vec![None, None, Some(true),])
3015            );
3016
3017            assert_eq!(
3018                get_corrected_filter_mask(
3019                    join_type,
3020                    &UInt64Array::from(vec![0, 0, 0]),
3021                    &[0usize; 3],
3022                    &BooleanArray::from(vec![false, true, true]),
3023                    output.num_rows()
3024                )
3025                .unwrap(),
3026                BooleanArray::from(vec![None, Some(true), None])
3027            );
3028
3029            assert_eq!(
3030                get_corrected_filter_mask(
3031                    join_type,
3032                    &UInt64Array::from(vec![0, 0, 0]),
3033                    &[0usize; 3],
3034                    &BooleanArray::from(vec![false, false, false]),
3035                    output.num_rows()
3036                )
3037                .unwrap(),
3038                BooleanArray::from(vec![None, None, None])
3039            );
3040
3041            let corrected_mask = get_corrected_filter_mask(
3042                join_type,
3043                &out_indices,
3044                &joined_batches.batch_ids,
3045                &out_mask,
3046                output.num_rows(),
3047            )
3048            .unwrap();
3049
3050            assert_eq!(
3051                corrected_mask,
3052                BooleanArray::from(vec![
3053                    Some(true),
3054                    None,
3055                    Some(true),
3056                    None,
3057                    Some(true),
3058                    None,
3059                    None,
3060                    None
3061                ])
3062            );
3063
3064            let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
3065
3066            assert_batches_eq!(
3067                &[
3068                    "+---+----+---+----+",
3069                    "| a | b  | x | y  |",
3070                    "+---+----+---+----+",
3071                    "| 1 | 10 | 1 | 11 |",
3072                    "| 1 | 11 | 1 | 12 |",
3073                    "| 1 | 12 | 1 | 13 |",
3074                    "+---+----+---+----+",
3075                ],
3076                &[filtered_rb]
3077            );
3078
3079            // output null rows
3080            let null_mask = arrow::compute::not(&corrected_mask)?;
3081            assert_eq!(
3082                null_mask,
3083                BooleanArray::from(vec![
3084                    Some(false),
3085                    None,
3086                    Some(false),
3087                    None,
3088                    Some(false),
3089                    None,
3090                    None,
3091                    None
3092                ])
3093            );
3094
3095            let null_joined_batch = filter_record_batch(&output, &null_mask)?;
3096
3097            assert_batches_eq!(
3098                &[
3099                    "+---+---+---+---+",
3100                    "| a | b | x | y |",
3101                    "+---+---+---+---+",
3102                    "+---+---+---+---+",
3103                ],
3104                &[null_joined_batch]
3105            );
3106        }
3107        Ok(())
3108    }
3109
3110    #[tokio::test]
3111    async fn test_anti_join_filtered_mask() -> Result<()> {
3112        for join_type in [LeftAnti, RightAnti] {
3113            let mut joined_batches = build_joined_record_batches()?;
3114            let schema = joined_batches.batches.first().unwrap().schema();
3115
3116            let output = concat_batches(&schema, &joined_batches.batches)?;
3117            let out_mask = joined_batches.filter_mask.finish();
3118            let out_indices = joined_batches.row_indices.finish();
3119
3120            assert_eq!(
3121                get_corrected_filter_mask(
3122                    join_type,
3123                    &UInt64Array::from(vec![0]),
3124                    &[0usize],
3125                    &BooleanArray::from(vec![true]),
3126                    1
3127                )
3128                .unwrap(),
3129                BooleanArray::from(vec![None])
3130            );
3131
3132            assert_eq!(
3133                get_corrected_filter_mask(
3134                    join_type,
3135                    &UInt64Array::from(vec![0]),
3136                    &[0usize],
3137                    &BooleanArray::from(vec![false]),
3138                    1
3139                )
3140                .unwrap(),
3141                BooleanArray::from(vec![Some(true)])
3142            );
3143
3144            assert_eq!(
3145                get_corrected_filter_mask(
3146                    join_type,
3147                    &UInt64Array::from(vec![0, 0]),
3148                    &[0usize; 2],
3149                    &BooleanArray::from(vec![true, true]),
3150                    2
3151                )
3152                .unwrap(),
3153                BooleanArray::from(vec![None, None])
3154            );
3155
3156            assert_eq!(
3157                get_corrected_filter_mask(
3158                    join_type,
3159                    &UInt64Array::from(vec![0, 0, 0]),
3160                    &[0usize; 3],
3161                    &BooleanArray::from(vec![true, true, true]),
3162                    3
3163                )
3164                .unwrap(),
3165                BooleanArray::from(vec![None, None, None])
3166            );
3167
3168            assert_eq!(
3169                get_corrected_filter_mask(
3170                    join_type,
3171                    &UInt64Array::from(vec![0, 0, 0]),
3172                    &[0usize; 3],
3173                    &BooleanArray::from(vec![true, false, true]),
3174                    3
3175                )
3176                .unwrap(),
3177                BooleanArray::from(vec![None, None, None])
3178            );
3179
3180            assert_eq!(
3181                get_corrected_filter_mask(
3182                    join_type,
3183                    &UInt64Array::from(vec![0, 0, 0]),
3184                    &[0usize; 3],
3185                    &BooleanArray::from(vec![false, false, true]),
3186                    3
3187                )
3188                .unwrap(),
3189                BooleanArray::from(vec![None, None, None])
3190            );
3191
3192            assert_eq!(
3193                get_corrected_filter_mask(
3194                    join_type,
3195                    &UInt64Array::from(vec![0, 0, 0]),
3196                    &[0usize; 3],
3197                    &BooleanArray::from(vec![false, true, true]),
3198                    3
3199                )
3200                .unwrap(),
3201                BooleanArray::from(vec![None, None, None])
3202            );
3203
3204            assert_eq!(
3205                get_corrected_filter_mask(
3206                    join_type,
3207                    &UInt64Array::from(vec![0, 0, 0]),
3208                    &[0usize; 3],
3209                    &BooleanArray::from(vec![false, false, false]),
3210                    3
3211                )
3212                .unwrap(),
3213                BooleanArray::from(vec![None, None, Some(true)])
3214            );
3215
3216            let corrected_mask = get_corrected_filter_mask(
3217                join_type,
3218                &out_indices,
3219                &joined_batches.batch_ids,
3220                &out_mask,
3221                output.num_rows(),
3222            )
3223            .unwrap();
3224
3225            assert_eq!(
3226                corrected_mask,
3227                BooleanArray::from(vec![
3228                    None,
3229                    None,
3230                    None,
3231                    None,
3232                    None,
3233                    Some(true),
3234                    None,
3235                    Some(true)
3236                ])
3237            );
3238
3239            let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
3240
3241            allow_duplicates! {
3242                assert_snapshot!(batches_to_string(&[filtered_rb]), @r#"
3243                    +---+----+---+----+
3244                    | a | b  | x | y  |
3245                    +---+----+---+----+
3246                    | 1 | 13 | 1 | 12 |
3247                    | 1 | 14 | 1 | 11 |
3248                    +---+----+---+----+
3249            "#);
3250            }
3251
3252            // output null rows
3253            let null_mask = arrow::compute::not(&corrected_mask)?;
3254            assert_eq!(
3255                null_mask,
3256                BooleanArray::from(vec![
3257                    None,
3258                    None,
3259                    None,
3260                    None,
3261                    None,
3262                    Some(false),
3263                    None,
3264                    Some(false),
3265                ])
3266            );
3267
3268            let null_joined_batch = filter_record_batch(&output, &null_mask)?;
3269
3270            allow_duplicates! {
3271                assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#"
3272                        +---+---+---+---+
3273                        | a | b | x | y |
3274                        +---+---+---+---+
3275                        +---+---+---+---+
3276                "#);
3277            }
3278        }
3279        Ok(())
3280    }
3281
3282    /// Returns the column names on the schema
3283    fn columns(schema: &Schema) -> Vec<String> {
3284        schema.fields().iter().map(|f| f.name().clone()).collect()
3285    }
3286}