datafusion_physical_plan/joins/
symmetric_hash_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file implements the symmetric hash join algorithm with range-based
19//! data pruning to join two (potentially infinite) streams.
20//!
//! A [`SymmetricHashJoinExec`] plan takes two child plans (with appropriate
//! output ordering) and produces the join output according to the given join
//! type and other options.
24//!
//! This plan uses one [`OneSideHashJoiner`] per child to carry out the join
//! calculations for that side.
27
28use std::any::Any;
29use std::fmt::{self, Debug};
30use std::mem::{size_of, size_of_val};
31use std::sync::Arc;
32use std::task::{Context, Poll};
33use std::vec;
34
35use crate::common::SharedMemoryReservation;
36use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
37use crate::joins::stream_join_utils::{
38    calculate_filter_expr_intervals, combine_two_batches,
39    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
40    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
41    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
42};
43use crate::joins::utils::{
44    apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
45    check_join_is_valid, equal_rows_arr, symmetric_join_output_partitioning, update_hash,
46    BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
47    JoinOnRef, NoopBatchTransformer, StatefulStreamResult,
48};
49use crate::projection::{
50    join_allows_pushdown, join_table_borders, new_join_children,
51    physical_to_column_exprs, update_join_filter, update_join_on, ProjectionExec,
52};
53use crate::{
54    joins::StreamJoinPartitionMode,
55    metrics::{ExecutionPlanMetricsSet, MetricsSet},
56    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
57    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
58};
59
60use arrow::array::{
61    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
62    UInt64Array,
63};
64use arrow::compute::concat_batches;
65use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
66use arrow::record_batch::RecordBatch;
67use datafusion_common::hash_utils::create_hashes;
68use datafusion_common::utils::bisect;
69use datafusion_common::{
70    internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, Result,
71};
72use datafusion_execution::memory_pool::MemoryConsumer;
73use datafusion_execution::TaskContext;
74use datafusion_expr::interval_arithmetic::Interval;
75use datafusion_physical_expr::equivalence::join_equivalence_properties;
76use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
77use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
78use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
79
80use ahash::RandomState;
81use futures::{ready, Stream, StreamExt};
82use parking_lot::Mutex;
83
/// Shrink scale factor for the join hash maps.
///
/// NOTE(review): this constant is not referenced in the visible part of the file;
/// presumably a hash map is shrunk once its capacity exceeds this multiple of its
/// length — confirm against the use site.
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
85
86/// A symmetric hash join with range conditions is when both streams are hashed on the
87/// join key and the resulting hash tables are used to join the streams.
88/// The join is considered symmetric because the hash table is built on the join keys from both
89/// streams, and the matching of rows is based on the values of the join keys in both streams.
90/// This type of join is efficient in streaming context as it allows for fast lookups in the hash
91/// table, rather than having to scan through one or both of the streams to find matching rows, also it
92/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions),
93/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
94/// data without any memory issues.
95///
96/// For each input stream, create a hash table.
97///   - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets.
98///   - Test if input is equal to a predefined set of other inputs.
99///   - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
100///   - Try to prune other side (probe) with new [RecordBatch].
101///   - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.),
102///     output the [RecordBatch] when a pruning happens or at the end of the data.
103///
104///
105/// ``` text
106///                        +-------------------------+
107///                        |                         |
108///   left stream ---------|  Left OneSideHashJoiner |---+
109///                        |                         |   |
110///                        +-------------------------+   |
111///                                                      |
112///                                                      |--------- Joined output
113///                                                      |
114///                        +-------------------------+   |
115///                        |                         |   |
116///  right stream ---------| Right OneSideHashJoiner |---+
117///                        |                         |
118///                        +-------------------------+
119///
120/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic
121/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
122///
123///
124///               PROBE SIDE          BUILD SIDE
125///                 BUFFER              BUFFER
126///             +-------------+     +------------+
127///             |             |     |            |    Unjoinable
128///             |             |     |            |    Range
129///             |             |     |            |
130///             |             |  |---------------------------------
131///             |             |  |  |            |
132///             |             |  |  |            |
133///             |             | /   |            |
134///             |             | |   |            |
135///             |             | |   |            |
136///             |             | |   |            |
137///             |             | |   |            |
138///             |             | |   |            |    Joinable
139///             |             |/    |            |    Range
140///             |             ||    |            |
141///             |+-----------+||    |            |
142///             || Record    ||     |            |
143///             || Batch     ||     |            |
144///             |+-----------+||    |            |
145///             +-------------+\    +------------+
146///                             |
147///                             \
148///                              |---------------------------------
149///
150///  This happens when range conditions are provided on sorted columns. E.g.
151///
152///        SELECT * FROM left_table, right_table
153///        ON
154///          left_key = right_key AND
155///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
156///
157/// or
158///       SELECT * FROM left_table, right_table
159///        ON
160///          left_key = right_key AND
161///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
162///
163/// For general purpose, in the second scenario, when the new data comes to probe side, the conditions can be used to
164/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort order the
165/// two columns ("left_sorted" and "right_sorted") are ascending (it can be different in another scenarios)
166/// and the join condition is "left_sorted > right_sorted - 3" and the latest value on the right input is 1234, meaning
167/// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" ,
168/// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending)
169/// than that can be dropped from the inner buffer.
170/// ```
171#[derive(Debug, Clone)]
172pub struct SymmetricHashJoinExec {
173    /// Left side stream
174    pub(crate) left: Arc<dyn ExecutionPlan>,
175    /// Right side stream
176    pub(crate) right: Arc<dyn ExecutionPlan>,
177    /// Set of common columns used to join on
178    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
179    /// Filters applied when finding matching rows
180    pub(crate) filter: Option<JoinFilter>,
181    /// How the join is performed
182    pub(crate) join_type: JoinType,
183    /// Shares the `RandomState` for the hashing algorithm
184    random_state: RandomState,
185    /// Execution metrics
186    metrics: ExecutionPlanMetricsSet,
187    /// Information of index and left / right placement of columns
188    column_indices: Vec<ColumnIndex>,
189    /// Defines the null equality for the join.
190    pub(crate) null_equality: NullEquality,
191    /// Left side sort expression(s)
192    pub(crate) left_sort_exprs: Option<LexOrdering>,
193    /// Right side sort expression(s)
194    pub(crate) right_sort_exprs: Option<LexOrdering>,
195    /// Partition Mode
196    mode: StreamJoinPartitionMode,
197    /// Cache holding plan properties like equivalences, output partitioning etc.
198    cache: PlanProperties,
199}
200
201impl SymmetricHashJoinExec {
202    /// Tries to create a new [SymmetricHashJoinExec].
203    /// # Error
204    /// This function errors when:
205    /// - It is not possible to join the left and right sides on keys `on`, or
206    /// - It fails to construct `SortedFilterExpr`s, or
207    /// - It fails to create the [ExprIntervalGraph].
208    #[allow(clippy::too_many_arguments)]
209    pub fn try_new(
210        left: Arc<dyn ExecutionPlan>,
211        right: Arc<dyn ExecutionPlan>,
212        on: JoinOn,
213        filter: Option<JoinFilter>,
214        join_type: &JoinType,
215        null_equality: NullEquality,
216        left_sort_exprs: Option<LexOrdering>,
217        right_sort_exprs: Option<LexOrdering>,
218        mode: StreamJoinPartitionMode,
219    ) -> Result<Self> {
220        let left_schema = left.schema();
221        let right_schema = right.schema();
222
223        // Error out if no "on" constraints are given:
224        if on.is_empty() {
225            return plan_err!(
226                "On constraints in SymmetricHashJoinExec should be non-empty"
227            );
228        }
229
230        // Check if the join is valid with the given on constraints:
231        check_join_is_valid(&left_schema, &right_schema, &on)?;
232
233        // Build the join schema from the left and right schemas:
234        let (schema, column_indices) =
235            build_join_schema(&left_schema, &right_schema, join_type);
236
237        // Initialize the random state for the join operation:
238        let random_state = RandomState::with_seeds(0, 0, 0, 0);
239        let schema = Arc::new(schema);
240        let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?;
241        Ok(SymmetricHashJoinExec {
242            left,
243            right,
244            on,
245            filter,
246            join_type: *join_type,
247            random_state,
248            metrics: ExecutionPlanMetricsSet::new(),
249            column_indices,
250            null_equality,
251            left_sort_exprs,
252            right_sort_exprs,
253            mode,
254            cache,
255        })
256    }
257
258    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
259    fn compute_properties(
260        left: &Arc<dyn ExecutionPlan>,
261        right: &Arc<dyn ExecutionPlan>,
262        schema: SchemaRef,
263        join_type: JoinType,
264        join_on: JoinOnRef,
265    ) -> Result<PlanProperties> {
266        // Calculate equivalence properties:
267        let eq_properties = join_equivalence_properties(
268            left.equivalence_properties().clone(),
269            right.equivalence_properties().clone(),
270            &join_type,
271            schema,
272            &[false, false],
273            // Has alternating probe side
274            None,
275            join_on,
276        )?;
277
278        let output_partitioning =
279            symmetric_join_output_partitioning(left, right, &join_type)?;
280
281        Ok(PlanProperties::new(
282            eq_properties,
283            output_partitioning,
284            emission_type_from_children([left, right]),
285            boundedness_from_children([left, right]),
286        ))
287    }
288
289    /// left stream
290    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
291        &self.left
292    }
293
294    /// right stream
295    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
296        &self.right
297    }
298
299    /// Set of common columns used to join on
300    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
301        &self.on
302    }
303
304    /// Filters applied before join output
305    pub fn filter(&self) -> Option<&JoinFilter> {
306        self.filter.as_ref()
307    }
308
309    /// How the join is performed
310    pub fn join_type(&self) -> &JoinType {
311        &self.join_type
312    }
313
314    /// Get null_equality
315    pub fn null_equality(&self) -> NullEquality {
316        self.null_equality
317    }
318
319    /// Get partition mode
320    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
321        self.mode
322    }
323
324    /// Get left_sort_exprs
325    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
326        self.left_sort_exprs.as_ref()
327    }
328
329    /// Get right_sort_exprs
330    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
331        self.right_sort_exprs.as_ref()
332    }
333
334    /// Check if order information covers every column in the filter expression.
335    pub fn check_if_order_information_available(&self) -> Result<bool> {
336        if let Some(filter) = self.filter() {
337            let left = self.left();
338            if let Some(left_ordering) = left.output_ordering() {
339                let right = self.right();
340                if let Some(right_ordering) = right.output_ordering() {
341                    let left_convertible = convert_sort_expr_with_filter_schema(
342                        &JoinSide::Left,
343                        filter,
344                        &left.schema(),
345                        &left_ordering[0],
346                    )?
347                    .is_some();
348                    let right_convertible = convert_sort_expr_with_filter_schema(
349                        &JoinSide::Right,
350                        filter,
351                        &right.schema(),
352                        &right_ordering[0],
353                    )?
354                    .is_some();
355                    return Ok(left_convertible && right_convertible);
356                }
357            }
358        }
359        Ok(false)
360    }
361}
362
363impl DisplayAs for SymmetricHashJoinExec {
364    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
365        match t {
366            DisplayFormatType::Default | DisplayFormatType::Verbose => {
367                let display_filter = self.filter.as_ref().map_or_else(
368                    || "".to_string(),
369                    |f| format!(", filter={}", f.expression()),
370                );
371                let on = self
372                    .on
373                    .iter()
374                    .map(|(c1, c2)| format!("({c1}, {c2})"))
375                    .collect::<Vec<String>>()
376                    .join(", ");
377                write!(
378                    f,
379                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
380                    self.mode, self.join_type, on, display_filter
381                )
382            }
383            DisplayFormatType::TreeRender => {
384                let on = self
385                    .on
386                    .iter()
387                    .map(|(c1, c2)| {
388                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
389                    })
390                    .collect::<Vec<String>>()
391                    .join(", ");
392
393                writeln!(f, "mode={:?}", self.mode)?;
394                if *self.join_type() != JoinType::Inner {
395                    writeln!(f, "join_type={:?}", self.join_type)?;
396                }
397                writeln!(f, "on={on}")
398            }
399        }
400    }
401}
402
403impl ExecutionPlan for SymmetricHashJoinExec {
404    fn name(&self) -> &'static str {
405        "SymmetricHashJoinExec"
406    }
407
408    fn as_any(&self) -> &dyn Any {
409        self
410    }
411
412    fn properties(&self) -> &PlanProperties {
413        &self.cache
414    }
415
416    fn required_input_distribution(&self) -> Vec<Distribution> {
417        match self.mode {
418            StreamJoinPartitionMode::Partitioned => {
419                let (left_expr, right_expr) = self
420                    .on
421                    .iter()
422                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
423                    .unzip();
424                vec![
425                    Distribution::HashPartitioned(left_expr),
426                    Distribution::HashPartitioned(right_expr),
427                ]
428            }
429            StreamJoinPartitionMode::SinglePartition => {
430                vec![Distribution::SinglePartition, Distribution::SinglePartition]
431            }
432        }
433    }
434
435    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
436        vec![
437            self.left_sort_exprs
438                .as_ref()
439                .map(|e| OrderingRequirements::from(e.clone())),
440            self.right_sort_exprs
441                .as_ref()
442                .map(|e| OrderingRequirements::from(e.clone())),
443        ]
444    }
445
446    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
447        vec![&self.left, &self.right]
448    }
449
450    fn with_new_children(
451        self: Arc<Self>,
452        children: Vec<Arc<dyn ExecutionPlan>>,
453    ) -> Result<Arc<dyn ExecutionPlan>> {
454        Ok(Arc::new(SymmetricHashJoinExec::try_new(
455            Arc::clone(&children[0]),
456            Arc::clone(&children[1]),
457            self.on.clone(),
458            self.filter.clone(),
459            &self.join_type,
460            self.null_equality,
461            self.left_sort_exprs.clone(),
462            self.right_sort_exprs.clone(),
463            self.mode,
464        )?))
465    }
466
467    fn metrics(&self) -> Option<MetricsSet> {
468        Some(self.metrics.clone_inner())
469    }
470
471    fn statistics(&self) -> Result<Statistics> {
472        // TODO stats: it is not possible in general to know the output size of joins
473        Ok(Statistics::new_unknown(&self.schema()))
474    }
475
476    fn execute(
477        &self,
478        partition: usize,
479        context: Arc<TaskContext>,
480    ) -> Result<SendableRecordBatchStream> {
481        let left_partitions = self.left.output_partitioning().partition_count();
482        let right_partitions = self.right.output_partitioning().partition_count();
483        if left_partitions != right_partitions {
484            return internal_err!(
485                "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
486                 consider using RepartitionExec"
487            );
488        }
489        // If `filter_state` and `filter` are both present, then calculate sorted
490        // filter expressions for both sides, and build an expression graph.
491        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
492            self.left_sort_exprs(),
493            self.right_sort_exprs(),
494            &self.filter,
495        ) {
496            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
497                let (left, right, graph) = prepare_sorted_exprs(
498                    filter,
499                    &self.left,
500                    &self.right,
501                    left_sort_exprs,
502                    right_sort_exprs,
503                )?;
504                (Some(left), Some(right), Some(graph))
505            }
506            // If `filter_state` or `filter` is not present, then return None
507            // for all three values:
508            _ => (None, None, None),
509        };
510
511        let (on_left, on_right) = self.on.iter().cloned().unzip();
512
513        let left_side_joiner =
514            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
515        let right_side_joiner =
516            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
517
518        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
519
520        let right_stream = self.right.execute(partition, Arc::clone(&context))?;
521
522        let batch_size = context.session_config().batch_size();
523        let enforce_batch_size_in_joins =
524            context.session_config().enforce_batch_size_in_joins();
525
526        let reservation = Arc::new(Mutex::new(
527            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
528                .register(context.memory_pool()),
529        ));
530        if let Some(g) = graph.as_ref() {
531            reservation.lock().try_grow(g.size())?;
532        }
533
534        if enforce_batch_size_in_joins {
535            Ok(Box::pin(SymmetricHashJoinStream {
536                left_stream,
537                right_stream,
538                schema: self.schema(),
539                filter: self.filter.clone(),
540                join_type: self.join_type,
541                random_state: self.random_state.clone(),
542                left: left_side_joiner,
543                right: right_side_joiner,
544                column_indices: self.column_indices.clone(),
545                metrics: StreamJoinMetrics::new(partition, &self.metrics),
546                graph,
547                left_sorted_filter_expr,
548                right_sorted_filter_expr,
549                null_equality: self.null_equality,
550                state: SHJStreamState::PullRight,
551                reservation,
552                batch_transformer: BatchSplitter::new(batch_size),
553            }))
554        } else {
555            Ok(Box::pin(SymmetricHashJoinStream {
556                left_stream,
557                right_stream,
558                schema: self.schema(),
559                filter: self.filter.clone(),
560                join_type: self.join_type,
561                random_state: self.random_state.clone(),
562                left: left_side_joiner,
563                right: right_side_joiner,
564                column_indices: self.column_indices.clone(),
565                metrics: StreamJoinMetrics::new(partition, &self.metrics),
566                graph,
567                left_sorted_filter_expr,
568                right_sorted_filter_expr,
569                null_equality: self.null_equality,
570                state: SHJStreamState::PullRight,
571                reservation,
572                batch_transformer: NoopBatchTransformer::new(),
573            }))
574        }
575    }
576
577    /// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done,
578    /// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan.
579    /// Otherwise, it returns None.
580    fn try_swapping_with_projection(
581        &self,
582        projection: &ProjectionExec,
583    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
584        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
585        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
586        else {
587            return Ok(None);
588        };
589
590        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
591            self.left().schema().fields().len(),
592            &projection_as_columns,
593        );
594
595        if !join_allows_pushdown(
596            &projection_as_columns,
597            &self.schema(),
598            far_right_left_col_ind,
599            far_left_right_col_ind,
600        ) {
601            return Ok(None);
602        }
603
604        let Some(new_on) = update_join_on(
605            &projection_as_columns[0..=far_right_left_col_ind as _],
606            &projection_as_columns[far_left_right_col_ind as _..],
607            self.on(),
608            self.left().schema().fields().len(),
609        ) else {
610            return Ok(None);
611        };
612
613        let new_filter = if let Some(filter) = self.filter() {
614            match update_join_filter(
615                &projection_as_columns[0..=far_right_left_col_ind as _],
616                &projection_as_columns[far_left_right_col_ind as _..],
617                filter,
618                self.left().schema().fields().len(),
619            ) {
620                Some(updated_filter) => Some(updated_filter),
621                None => return Ok(None),
622            }
623        } else {
624            None
625        };
626
627        let (new_left, new_right) = new_join_children(
628            &projection_as_columns,
629            far_right_left_col_ind,
630            far_left_right_col_ind,
631            self.left(),
632            self.right(),
633        )?;
634
635        SymmetricHashJoinExec::try_new(
636            Arc::new(new_left),
637            Arc::new(new_right),
638            new_on,
639            new_filter,
640            self.join_type(),
641            self.null_equality(),
642            self.right().output_ordering().cloned(),
643            self.left().output_ordering().cloned(),
644            self.partition_mode(),
645        )
646        .map(|e| Some(Arc::new(e) as _))
647    }
648}
649
650/// A stream that issues [RecordBatch]es as they arrive from the right  of the join.
651struct SymmetricHashJoinStream<T> {
652    /// Input streams
653    left_stream: SendableRecordBatchStream,
654    right_stream: SendableRecordBatchStream,
655    /// Input schema
656    schema: Arc<Schema>,
657    /// join filter
658    filter: Option<JoinFilter>,
659    /// type of the join
660    join_type: JoinType,
661    // left hash joiner
662    left: OneSideHashJoiner,
663    /// right hash joiner
664    right: OneSideHashJoiner,
665    /// Information of index and left / right placement of columns
666    column_indices: Vec<ColumnIndex>,
667    // Expression graph for range pruning.
668    graph: Option<ExprIntervalGraph>,
669    // Left globally sorted filter expr
670    left_sorted_filter_expr: Option<SortedFilterExpr>,
671    // Right globally sorted filter expr
672    right_sorted_filter_expr: Option<SortedFilterExpr>,
673    /// Random state used for hashing initialization
674    random_state: RandomState,
675    /// Defines the null equality for the join.
676    null_equality: NullEquality,
677    /// Metrics
678    metrics: StreamJoinMetrics,
679    /// Memory reservation
680    reservation: SharedMemoryReservation,
681    /// State machine for input execution
682    state: SHJStreamState,
683    /// Transforms the output batch before returning.
684    batch_transformer: T,
685}
686
687impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
688    for SymmetricHashJoinStream<T>
689{
690    fn schema(&self) -> SchemaRef {
691        Arc::clone(&self.schema)
692    }
693}
694
695impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
696    type Item = Result<RecordBatch>;
697
698    fn poll_next(
699        mut self: std::pin::Pin<&mut Self>,
700        cx: &mut Context<'_>,
701    ) -> Poll<Option<Self::Item>> {
702        self.poll_next_impl(cx)
703    }
704}
705
706/// Determine the pruning length for `buffer`.
707///
708/// This function evaluates the build side filter expression, converts the
709/// result into an array and determines the pruning length by performing a
710/// binary search on the array.
711///
712/// # Arguments
713///
714/// * `buffer`: The record batch to be pruned.
715/// * `build_side_filter_expr`: The filter expression on the build side used
716///   to determine the pruning length.
717///
718/// # Returns
719///
720/// A [Result] object that contains the pruning length. The function will return
721/// an error if
722/// - there is an issue evaluating the build side filter expression;
723/// - there is an issue converting the build side filter expression into an array
724fn determine_prune_length(
725    buffer: &RecordBatch,
726    build_side_filter_expr: &SortedFilterExpr,
727) -> Result<usize> {
728    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
729    let interval = build_side_filter_expr.interval();
730    // Evaluate the build side filter expression and convert it into an array
731    let batch_arr = origin_sorted_expr
732        .expr
733        .evaluate(buffer)?
734        .into_array(buffer.num_rows())?;
735
736    // Get the lower or upper interval based on the sort direction
737    let target = if origin_sorted_expr.options.descending {
738        interval.upper().clone()
739    } else {
740        interval.lower().clone()
741    };
742
743    // Perform binary search on the array to determine the length of the record batch to be pruned
744    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
745}
746
747/// This method determines if the result of the join should be produced in the final step or not.
748///
749/// # Arguments
750///
751/// * `build_side` - Enum indicating the side of the join used as the build side.
752/// * `join_type` - Enum indicating the type of join to be performed.
753///
754/// # Returns
755///
756/// A boolean indicating whether the result of the join should be produced in the final step or not.
757/// The result will be true if the build side is JoinSide::Left and the join type is one of
758/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi.
759/// If the build side is JoinSide::Right, the result will be true if the join type
760/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi.
761fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
762    if build_side == JoinSide::Left {
763        matches!(
764            join_type,
765            JoinType::Left
766                | JoinType::LeftAnti
767                | JoinType::Full
768                | JoinType::LeftSemi
769                | JoinType::LeftMark
770        )
771    } else {
772        matches!(
773            join_type,
774            JoinType::Right
775                | JoinType::RightAnti
776                | JoinType::Full
777                | JoinType::RightSemi
778                | JoinType::RightMark
779        )
780    }
781}
782
783/// Calculate indices by join type.
784///
785/// This method returns a tuple of two arrays: build and probe indices.
786/// The length of both arrays will be the same.
787///
788/// # Arguments
789///
790/// * `build_side`: Join side which defines the build side.
791/// * `prune_length`: Length of the prune data.
792/// * `visited_rows`: Hash set of visited rows of the build side.
793/// * `deleted_offset`: Deleted offset of the build side.
794/// * `join_type`: The type of join to be performed.
795///
796/// # Returns
797///
798/// A tuple of two arrays of primitive types representing the build and probe indices.
799fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
800    build_side: JoinSide,
801    prune_length: usize,
802    visited_rows: &HashSet<usize>,
803    deleted_offset: usize,
804    join_type: JoinType,
805) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
806where
807    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
808{
809    // Store the result in a tuple
810    let result = match (build_side, join_type) {
811        // For a mark join we “mark” each build‐side row with a dummy 0 in the probe‐side index
812        // if it ever matched. For example, if
813        //
814        // prune_length = 5
815        // deleted_offset = 0
816        // visited_rows = {1, 3}
817        //
818        // then we produce:
819        //
820        // build_indices = [0, 1, 2, 3, 4]
821        // probe_indices = [None, Some(0), None, Some(0), None]
822        //
823        // Example: for each build row i in [0..5):
824        //   – We always output its own index i in `build_indices`
825        //   – We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None`
826        (JoinSide::Left, JoinType::LeftMark) => {
827            let build_indices = (0..prune_length)
828                .map(L::Native::from_usize)
829                .collect::<PrimitiveArray<L>>();
830            let probe_indices = (0..prune_length)
831                .map(|idx| {
832                    // For mark join we output a dummy index 0 to indicate the row had a match
833                    visited_rows
834                        .contains(&(idx + deleted_offset))
835                        .then_some(R::Native::from_usize(0).unwrap())
836                })
837                .collect();
838            (build_indices, probe_indices)
839        }
840        (JoinSide::Right, JoinType::RightMark) => {
841            let build_indices = (0..prune_length)
842                .map(L::Native::from_usize)
843                .collect::<PrimitiveArray<L>>();
844            let probe_indices = (0..prune_length)
845                .map(|idx| {
846                    // For mark join we output a dummy index 0 to indicate the row had a match
847                    visited_rows
848                        .contains(&(idx + deleted_offset))
849                        .then_some(R::Native::from_usize(0).unwrap())
850                })
851                .collect();
852            (build_indices, probe_indices)
853        }
854        // In the case of `Left` or `Right` join, or `Full` join, get the anti indices
855        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
856        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
857        | (_, JoinType::Full) => {
858            let build_unmatched_indices =
859                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
860            let mut builder =
861                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
862            builder.append_nulls(build_unmatched_indices.len());
863            let probe_indices = builder.finish();
864            (build_unmatched_indices, probe_indices)
865        }
866        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
867        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
868            let build_unmatched_indices =
869                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
870            let mut builder =
871                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
872            builder.append_nulls(build_unmatched_indices.len());
873            let probe_indices = builder.finish();
874            (build_unmatched_indices, probe_indices)
875        }
876        // The case of other join types is not considered
877        _ => unreachable!(),
878    };
879    Ok(result)
880}
881
882/// This function produces unmatched record results based on the build side,
883/// join type and other parameters.
884///
885/// The method uses first `prune_length` rows from the build side input buffer
886/// to produce results.
887///
888/// # Arguments
889///
890/// * `output_schema` - The schema of the final output record batch.
891/// * `prune_length` - The length of the determined prune length.
892/// * `probe_schema` - The schema of the probe [RecordBatch].
893/// * `join_type` - The type of join to be performed.
894/// * `column_indices` - Indices of columns that are being joined.
895///
896/// # Returns
897///
898/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
899pub(crate) fn build_side_determined_results(
900    build_hash_joiner: &OneSideHashJoiner,
901    output_schema: &SchemaRef,
902    prune_length: usize,
903    probe_schema: SchemaRef,
904    join_type: JoinType,
905    column_indices: &[ColumnIndex],
906) -> Result<Option<RecordBatch>> {
907    // Check if we need to produce a result in the final output:
908    if prune_length > 0
909        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
910    {
911        // Calculate the indices for build and probe sides based on join type and build side:
912        let (build_indices, probe_indices) = calculate_indices_by_join_type(
913            build_hash_joiner.build_side,
914            prune_length,
915            &build_hash_joiner.visited_rows,
916            build_hash_joiner.deleted_offset,
917            join_type,
918        )?;
919
920        // Create an empty probe record batch:
921        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
922        // Build the final result from the indices of build and probe sides:
923        build_batch_from_indices(
924            output_schema.as_ref(),
925            &build_hash_joiner.input_buffer,
926            &empty_probe_batch,
927            &build_indices,
928            &probe_indices,
929            column_indices,
930            build_hash_joiner.build_side,
931        )
932        .map(|batch| (batch.num_rows() > 0).then_some(batch))
933    } else {
934        // If we don't need to produce a result, return None
935        Ok(None)
936    }
937}
938
939/// This method performs a join between the build side input buffer and the probe side batch.
940///
941/// # Arguments
942///
943/// * `build_hash_joiner` - Build side hash joiner
944/// * `probe_hash_joiner` - Probe side hash joiner
945/// * `schema` - A reference to the schema of the output record batch.
946/// * `join_type` - The type of join to be performed.
947/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
948/// * `filter` - An optional filter on the join condition.
949/// * `probe_batch` - The second record batch to be joined.
950/// * `column_indices` - An array of columns to be selected for the result of the join.
951/// * `random_state` - The random state for the join.
952/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
953///
954/// # Returns
955///
956/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
957/// If the join type is one of the above four, the function will return [None].
958#[allow(clippy::too_many_arguments)]
959pub(crate) fn join_with_probe_batch(
960    build_hash_joiner: &mut OneSideHashJoiner,
961    probe_hash_joiner: &mut OneSideHashJoiner,
962    schema: &SchemaRef,
963    join_type: JoinType,
964    filter: Option<&JoinFilter>,
965    probe_batch: &RecordBatch,
966    column_indices: &[ColumnIndex],
967    random_state: &RandomState,
968    null_equality: NullEquality,
969) -> Result<Option<RecordBatch>> {
970    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
971        return Ok(None);
972    }
973    let (build_indices, probe_indices) = lookup_join_hashmap(
974        &build_hash_joiner.hashmap,
975        &build_hash_joiner.input_buffer,
976        probe_batch,
977        &build_hash_joiner.on,
978        &probe_hash_joiner.on,
979        random_state,
980        null_equality,
981        &mut build_hash_joiner.hashes_buffer,
982        Some(build_hash_joiner.deleted_offset),
983    )?;
984
985    let (build_indices, probe_indices) = if let Some(filter) = filter {
986        apply_join_filter_to_indices(
987            &build_hash_joiner.input_buffer,
988            probe_batch,
989            build_indices,
990            probe_indices,
991            filter,
992            build_hash_joiner.build_side,
993            None,
994        )?
995    } else {
996        (build_indices, probe_indices)
997    };
998
999    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
1000        record_visited_indices(
1001            &mut build_hash_joiner.visited_rows,
1002            build_hash_joiner.deleted_offset,
1003            &build_indices,
1004        );
1005    }
1006    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
1007        record_visited_indices(
1008            &mut probe_hash_joiner.visited_rows,
1009            probe_hash_joiner.offset,
1010            &probe_indices,
1011        );
1012    }
1013    if matches!(
1014        join_type,
1015        JoinType::LeftAnti
1016            | JoinType::RightAnti
1017            | JoinType::LeftSemi
1018            | JoinType::LeftMark
1019            | JoinType::RightSemi
1020            | JoinType::RightMark
1021    ) {
1022        Ok(None)
1023    } else {
1024        build_batch_from_indices(
1025            schema,
1026            &build_hash_joiner.input_buffer,
1027            probe_batch,
1028            &build_indices,
1029            &probe_indices,
1030            column_indices,
1031            build_hash_joiner.build_side,
1032        )
1033        .map(|batch| (batch.num_rows() > 0).then_some(batch))
1034    }
1035}
1036
1037/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential
1038/// hash collisions.
1039///
1040/// # Arguments
1041///
1042/// * `build_hashmap` - hashmap collected from build side data.
1043/// * `build_batch` - Build side record batch.
1044/// * `probe_batch` - Probe side record batch.
1045/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
1046/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
1047/// * `random_state` - The random state for the join.
1048/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
1049/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
1050/// * `deleted_offset` - deleted offset for build side data.
1051///
1052/// # Returns
1053///
1054/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side,
1055/// matched by join key columns.
1056#[allow(clippy::too_many_arguments)]
1057fn lookup_join_hashmap(
1058    build_hashmap: &PruningJoinHashMap,
1059    build_batch: &RecordBatch,
1060    probe_batch: &RecordBatch,
1061    build_on: &[PhysicalExprRef],
1062    probe_on: &[PhysicalExprRef],
1063    random_state: &RandomState,
1064    null_equality: NullEquality,
1065    hashes_buffer: &mut Vec<u64>,
1066    deleted_offset: Option<usize>,
1067) -> Result<(UInt64Array, UInt32Array)> {
1068    let keys_values = probe_on
1069        .iter()
1070        .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))
1071        .collect::<Result<Vec<_>>>()?;
1072    let build_join_values = build_on
1073        .iter()
1074        .map(|c| c.evaluate(build_batch)?.into_array(build_batch.num_rows()))
1075        .collect::<Result<Vec<_>>>()?;
1076
1077    hashes_buffer.clear();
1078    hashes_buffer.resize(probe_batch.num_rows(), 0);
1079    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1080
1081    // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm
1082    // will return build indices for each probe row in a reverse order as such:
1083    // Build Indices: [5, 4, 3]
1084    // Probe Indices: [1, 1, 1]
1085    //
1086    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
1087    // Let's consider probe rows [0,1] as an example:
1088    //
1089    // When the probe iteration sequence is reversed, the following pairings can be derived:
1090    //
1091    // For probe row 1:
1092    //     (5, 1)
1093    //     (4, 1)
1094    //     (3, 1)
1095    //
1096    // For probe row 0:
1097    //     (5, 0)
1098    //     (4, 0)
1099    //     (3, 0)
1100    //
1101    // After reversing both sets of indices, we obtain reversed indices:
1102    //
1103    //     (3,0)
1104    //     (4,0)
1105    //     (5,0)
1106    //     (3,1)
1107    //     (4,1)
1108    //     (5,1)
1109    //
1110    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
1111    let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices(
1112        Box::new(hash_values.iter().enumerate().rev()),
1113        deleted_offset,
1114    );
1115
1116    matched_probe.reverse();
1117    matched_build.reverse();
1118
1119    let build_indices: UInt64Array = matched_build.into();
1120    let probe_indices: UInt32Array = matched_probe.into();
1121
1122    let (build_indices, probe_indices) = equal_rows_arr(
1123        &build_indices,
1124        &probe_indices,
1125        &build_join_values,
1126        &keys_values,
1127        null_equality,
1128    )?;
1129
1130    Ok((build_indices, probe_indices))
1131}
1132
/// State kept for one input of the symmetric hash join.
///
/// Holds the rows buffered so far from that input, a hash map over their join
/// keys, the set of rows that have found a match, and the offsets that map
/// buffer positions to absolute row indices across pruning.
pub struct OneSideHashJoiner {
    /// The join side whose rows this joiner buffers
    build_side: JoinSide,
    /// Rows received from this side that have not been pruned yet; grown by
    /// concatenating incoming batches and shrunk by slicing off pruned rows
    pub input_buffer: RecordBatch,
    /// Join-key expressions evaluated against this side's batches
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hash map over the join-key hashes of the buffered rows (presumably
    /// mapping hashes to row indices — see `PruningJoinHashMap`)
    pub(crate) hashmap: PruningJoinHashMap,
    /// Scratch buffer for key hashes, reused across batches to avoid reallocating
    pub(crate) hashes_buffer: Vec<u64>,
    /// Absolute indices (buffer index + `deleted_offset`) of rows that have
    /// found at least one match; only maintained for join types that produce
    /// results for this side at pruning/final time
    pub(crate) visited_rows: HashSet<usize>,
    /// Absolute row index at which the current incoming batch starts; it is
    /// not advanced by the methods of this type, so presumably the caller
    /// increments it after each batch — NOTE(review): confirm
    pub(crate) offset: usize,
    /// Number of rows pruned from the front of `input_buffer` so far
    pub(crate) deleted_offset: usize,
}
1151
1152impl OneSideHashJoiner {
1153    pub fn size(&self) -> usize {
1154        let mut size = 0;
1155        size += size_of_val(self);
1156        size += size_of_val(&self.build_side);
1157        size += self.input_buffer.get_array_memory_size();
1158        size += size_of_val(&self.on);
1159        size += self.hashmap.size();
1160        size += self.hashes_buffer.capacity() * size_of::<u64>();
1161        size += self.visited_rows.capacity() * size_of::<usize>();
1162        size += size_of_val(&self.offset);
1163        size += size_of_val(&self.deleted_offset);
1164        size
1165    }
1166    pub fn new(
1167        build_side: JoinSide,
1168        on: Vec<PhysicalExprRef>,
1169        schema: SchemaRef,
1170    ) -> Self {
1171        Self {
1172            build_side,
1173            input_buffer: RecordBatch::new_empty(schema),
1174            on,
1175            hashmap: PruningJoinHashMap::with_capacity(0),
1176            hashes_buffer: vec![],
1177            visited_rows: HashSet::new(),
1178            offset: 0,
1179            deleted_offset: 0,
1180        }
1181    }
1182
1183    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
1184    ///
1185    /// # Arguments
1186    ///
1187    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
1188    /// * `random_state` - The random state used to hash values
1189    ///
1190    /// # Returns
1191    ///
1192    /// Returns a [Result] encapsulating any intermediate errors.
1193    pub(crate) fn update_internal_state(
1194        &mut self,
1195        batch: &RecordBatch,
1196        random_state: &RandomState,
1197    ) -> Result<()> {
1198        // Merge the incoming batch with the existing input buffer:
1199        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
1200        // Resize the hashes buffer to the number of rows in the incoming batch:
1201        self.hashes_buffer.resize(batch.num_rows(), 0);
1202        // Get allocation_info before adding the item
1203        // Update the hashmap with the join key values and hashes of the incoming batch:
1204        update_hash(
1205            &self.on,
1206            batch,
1207            &mut self.hashmap,
1208            self.offset,
1209            random_state,
1210            &mut self.hashes_buffer,
1211            self.deleted_offset,
1212            false,
1213        )?;
1214        Ok(())
1215    }
1216
1217    /// Calculate prune length.
1218    ///
1219    /// # Arguments
1220    ///
1221    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression..
1222    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
1223    /// * `graph` - A mutable reference to the physical expression graph.
1224    ///
1225    /// # Returns
1226    ///
1227    /// A Result object that contains the pruning length.
1228    pub(crate) fn calculate_prune_length_with_probe_batch(
1229        &mut self,
1230        build_side_sorted_filter_expr: &mut SortedFilterExpr,
1231        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
1232        graph: &mut ExprIntervalGraph,
1233    ) -> Result<usize> {
1234        // Return early if the input buffer is empty:
1235        if self.input_buffer.num_rows() == 0 {
1236            return Ok(0);
1237        }
1238        // Process the build and probe side sorted filter expressions if both are present:
1239        // Collect the sorted filter expressions into a vector of (node_index, interval) tuples:
1240        let mut filter_intervals = vec![];
1241        for expr in [
1242            &build_side_sorted_filter_expr,
1243            &probe_side_sorted_filter_expr,
1244        ] {
1245            filter_intervals.push((expr.node_index(), expr.interval().clone()))
1246        }
1247        // Update the physical expression graph using the join filter intervals:
1248        graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?;
1249        // Extract the new join filter interval for the build side:
1250        let calculated_build_side_interval = filter_intervals.remove(0).1;
1251        // If the intervals have not changed, return early without pruning:
1252        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
1253            return Ok(0);
1254        }
1255        // Update the build side interval and determine the pruning length:
1256        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);
1257
1258        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
1259    }
1260
1261    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
1262        // Prune the hash values:
1263        self.hashmap.prune_hash_values(
1264            prune_length,
1265            self.deleted_offset as u64,
1266            HASHMAP_SHRINK_SCALE_FACTOR,
1267        );
1268        // Remove pruned rows from the visited rows set:
1269        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
1270            self.visited_rows.remove(&row);
1271        }
1272        // Update the input buffer after pruning:
1273        self.input_buffer = self
1274            .input_buffer
1275            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
1276        // Increment the deleted offset:
1277        self.deleted_offset += prune_length;
1278        Ok(())
1279    }
1280}
1281
1282/// `SymmetricHashJoinStream` manages incremental join operations between two
1283/// streams. Unlike traditional join approaches that need to scan one side of
1284/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
1285/// more dynamic join operations by working with streams as they emit data. This
1286/// approach allows for more efficient processing, particularly in scenarios
1287/// where waiting for complete data materialization is not feasible or optimal.
/// The methods in this `impl` block provide a framework for handling the
/// various states of such a join process, ensuring that join logic is
/// efficiently executed as data becomes available from either stream.
1291///
1292/// This implementation performs eager joins of data from two different asynchronous
1293/// streams, typically referred to as left and right streams. The implementation
1294/// provides a comprehensive set of methods to control and execute the join
1295/// process, leveraging the states defined in `SHJStreamState`. Methods are
1296/// primarily focused on asynchronously fetching data batches from each stream,
1297/// processing them, and managing transitions between various states of the join.
1298///
1299/// This implementations use a state machine approach to navigate different
1300/// stages of the join operation, handling data from both streams and determining
1301/// when the join completes.
1302///
1303/// State Transitions:
1304/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
1305///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
1306///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
1307///       processing the batch.
1308///     - On error (`Some(Err(e))`), the error is returned, and the state remains
1309///       unchanged.
1310///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
1311///       to proceed with the join process.
1312/// - From `PullRight` to `PullLeft` or `RightExhausted`:
1313///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
1314///     - If a batch is available, state changes to `PullLeft` for processing.
1315///     - On error, the error is returned without changing the state.
1316///     - If right stream is exhausted (`None`), state transitions to `RightExhausted`,
1317///       with a `Continue` result.
1318/// - Handling `RightExhausted` and `LeftExhausted`:
1319///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
1320///     when streams are exhausted:
1321///     - They attempt to continue processing with the other stream.
1322///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
1323/// - Transition to `BothExhausted { final_result: true }`:
1324///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
1325///     exhausted, indicating completion of processing and availability of final results.
1326impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
1327    /// Implements the main polling logic for the join stream.
1328    ///
1329    /// This method continuously checks the state of the join stream and
1330    /// acts accordingly by delegating the handling to appropriate sub-methods
1331    /// depending on the current state.
1332    ///
1333    /// # Arguments
1334    ///
1335    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
1336    ///
1337    /// # Returns
1338    ///
1339    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
1340    fn poll_next_impl(
1341        &mut self,
1342        cx: &mut Context<'_>,
1343    ) -> Poll<Option<Result<RecordBatch>>> {
1344        loop {
1345            match self.batch_transformer.next() {
1346                None => {
1347                    let result = match self.state() {
1348                        SHJStreamState::PullRight => {
1349                            ready!(self.fetch_next_from_right_stream(cx))
1350                        }
1351                        SHJStreamState::PullLeft => {
1352                            ready!(self.fetch_next_from_left_stream(cx))
1353                        }
1354                        SHJStreamState::RightExhausted => {
1355                            ready!(self.handle_right_stream_end(cx))
1356                        }
1357                        SHJStreamState::LeftExhausted => {
1358                            ready!(self.handle_left_stream_end(cx))
1359                        }
1360                        SHJStreamState::BothExhausted {
1361                            final_result: false,
1362                        } => self.prepare_for_final_results_after_exhaustion(),
1363                        SHJStreamState::BothExhausted { final_result: true } => {
1364                            return Poll::Ready(None);
1365                        }
1366                    };
1367
1368                    match result? {
1369                        StatefulStreamResult::Ready(None) => {
1370                            return Poll::Ready(None);
1371                        }
1372                        StatefulStreamResult::Ready(Some(batch)) => {
1373                            self.batch_transformer.set_batch(batch);
1374                        }
1375                        _ => {}
1376                    }
1377                }
1378                Some((batch, _)) => {
1379                    self.metrics.output_batches.add(1);
1380                    return self
1381                        .metrics
1382                        .baseline_metrics
1383                        .record_poll(Poll::Ready(Some(Ok(batch))));
1384                }
1385            }
1386        }
1387    }
1388    /// Asynchronously pulls the next batch from the right stream.
1389    ///
1390    /// This default implementation checks for the next value in the right stream.
1391    /// If a batch is found, the state is switched to `PullLeft`, and the batch handling
1392    /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`.
1393    ///
1394    /// # Returns
1395    ///
1396    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1397    fn fetch_next_from_right_stream(
1398        &mut self,
1399        cx: &mut Context<'_>,
1400    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1401        match ready!(self.right_stream().poll_next_unpin(cx)) {
1402            Some(Ok(batch)) => {
1403                if batch.num_rows() == 0 {
1404                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1405                }
1406                self.set_state(SHJStreamState::PullLeft);
1407                Poll::Ready(self.process_batch_from_right(batch))
1408            }
1409            Some(Err(e)) => Poll::Ready(Err(e)),
1410            None => {
1411                self.set_state(SHJStreamState::RightExhausted);
1412                Poll::Ready(Ok(StatefulStreamResult::Continue))
1413            }
1414        }
1415    }
1416
1417    /// Asynchronously pulls the next batch from the left stream.
1418    ///
1419    /// This default implementation checks for the next value in the left stream.
1420    /// If a batch is found, the state is switched to `PullRight`, and the batch handling
1421    /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`.
1422    ///
1423    /// # Returns
1424    ///
1425    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1426    fn fetch_next_from_left_stream(
1427        &mut self,
1428        cx: &mut Context<'_>,
1429    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1430        match ready!(self.left_stream().poll_next_unpin(cx)) {
1431            Some(Ok(batch)) => {
1432                if batch.num_rows() == 0 {
1433                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1434                }
1435                self.set_state(SHJStreamState::PullRight);
1436                Poll::Ready(self.process_batch_from_left(batch))
1437            }
1438            Some(Err(e)) => Poll::Ready(Err(e)),
1439            None => {
1440                self.set_state(SHJStreamState::LeftExhausted);
1441                Poll::Ready(Ok(StatefulStreamResult::Continue))
1442            }
1443        }
1444    }
1445
1446    /// Asynchronously handles the scenario when the right stream is exhausted.
1447    ///
1448    /// In this default implementation, when the right stream is exhausted, it attempts
1449    /// to pull from the left stream. If a batch is found in the left stream, it delegates
1450    /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set
1451    /// to indicate both streams are exhausted without final results yet.
1452    ///
1453    /// # Returns
1454    ///
1455    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1456    fn handle_right_stream_end(
1457        &mut self,
1458        cx: &mut Context<'_>,
1459    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1460        match ready!(self.left_stream().poll_next_unpin(cx)) {
1461            Some(Ok(batch)) => {
1462                if batch.num_rows() == 0 {
1463                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1464                }
1465                Poll::Ready(self.process_batch_after_right_end(batch))
1466            }
1467            Some(Err(e)) => Poll::Ready(Err(e)),
1468            None => {
1469                self.set_state(SHJStreamState::BothExhausted {
1470                    final_result: false,
1471                });
1472                Poll::Ready(Ok(StatefulStreamResult::Continue))
1473            }
1474        }
1475    }
1476
1477    /// Asynchronously handles the scenario when the left stream is exhausted.
1478    ///
1479    /// When the left stream is exhausted, this default
1480    /// implementation tries to pull from the right stream and delegates the batch
1481    /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state
1482    /// is updated to indicate so.
1483    ///
1484    /// # Returns
1485    ///
1486    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1487    fn handle_left_stream_end(
1488        &mut self,
1489        cx: &mut Context<'_>,
1490    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1491        match ready!(self.right_stream().poll_next_unpin(cx)) {
1492            Some(Ok(batch)) => {
1493                if batch.num_rows() == 0 {
1494                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1495                }
1496                Poll::Ready(self.process_batch_after_left_end(batch))
1497            }
1498            Some(Err(e)) => Poll::Ready(Err(e)),
1499            None => {
1500                self.set_state(SHJStreamState::BothExhausted {
1501                    final_result: false,
1502                });
1503                Poll::Ready(Ok(StatefulStreamResult::Continue))
1504            }
1505        }
1506    }
1507
1508    /// Handles the state when both streams are exhausted and final results are yet to be produced.
1509    ///
1510    /// This default implementation switches the state to indicate both streams are
1511    /// exhausted with final results and then invokes the handling for this specific
1512    /// scenario via `process_batches_before_finalization`.
1513    ///
1514    /// # Returns
1515    ///
1516    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
1517    fn prepare_for_final_results_after_exhaustion(
1518        &mut self,
1519    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1520        self.set_state(SHJStreamState::BothExhausted { final_result: true });
1521        self.process_batches_before_finalization()
1522    }
1523
1524    fn process_batch_from_right(
1525        &mut self,
1526        batch: RecordBatch,
1527    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1528        self.perform_join_for_given_side(batch, JoinSide::Right)
1529            .map(|maybe_batch| {
1530                if maybe_batch.is_some() {
1531                    StatefulStreamResult::Ready(maybe_batch)
1532                } else {
1533                    StatefulStreamResult::Continue
1534                }
1535            })
1536    }
1537
1538    fn process_batch_from_left(
1539        &mut self,
1540        batch: RecordBatch,
1541    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1542        self.perform_join_for_given_side(batch, JoinSide::Left)
1543            .map(|maybe_batch| {
1544                if maybe_batch.is_some() {
1545                    StatefulStreamResult::Ready(maybe_batch)
1546                } else {
1547                    StatefulStreamResult::Continue
1548                }
1549            })
1550    }
1551
1552    fn process_batch_after_left_end(
1553        &mut self,
1554        right_batch: RecordBatch,
1555    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1556        self.process_batch_from_right(right_batch)
1557    }
1558
1559    fn process_batch_after_right_end(
1560        &mut self,
1561        left_batch: RecordBatch,
1562    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1563        self.process_batch_from_left(left_batch)
1564    }
1565
1566    fn process_batches_before_finalization(
1567        &mut self,
1568    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1569        // Get the left side results:
1570        let left_result = build_side_determined_results(
1571            &self.left,
1572            &self.schema,
1573            self.left.input_buffer.num_rows(),
1574            self.right.input_buffer.schema(),
1575            self.join_type,
1576            &self.column_indices,
1577        )?;
1578        // Get the right side results:
1579        let right_result = build_side_determined_results(
1580            &self.right,
1581            &self.schema,
1582            self.right.input_buffer.num_rows(),
1583            self.left.input_buffer.schema(),
1584            self.join_type,
1585            &self.column_indices,
1586        )?;
1587
1588        // Combine the left and right results:
1589        let result = combine_two_batches(&self.schema, left_result, right_result)?;
1590
1591        // Return the result:
1592        if result.is_some() {
1593            return Ok(StatefulStreamResult::Ready(result));
1594        }
1595        Ok(StatefulStreamResult::Continue)
1596    }
1597
1598    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
1599        &mut self.right_stream
1600    }
1601
1602    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
1603        &mut self.left_stream
1604    }
1605
1606    fn set_state(&mut self, state: SHJStreamState) {
1607        self.state = state;
1608    }
1609
1610    fn state(&mut self) -> SHJStreamState {
1611        self.state.clone()
1612    }
1613
1614    fn size(&self) -> usize {
1615        let mut size = 0;
1616        size += size_of_val(&self.schema);
1617        size += size_of_val(&self.filter);
1618        size += size_of_val(&self.join_type);
1619        size += self.left.size();
1620        size += self.right.size();
1621        size += size_of_val(&self.column_indices);
1622        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
1623        size += size_of_val(&self.left_sorted_filter_expr);
1624        size += size_of_val(&self.right_sorted_filter_expr);
1625        size += size_of_val(&self.random_state);
1626        size += size_of_val(&self.null_equality);
1627        size += size_of_val(&self.metrics);
1628        size
1629    }
1630
    /// Joins one polled batch from `probe_side` against the buffered rows of the
    /// opposite side and returns any output produced.
    ///
    /// Steps, in order:
    /// 1. Picks the probe/build hash joiners, sorted filter expressions and metrics
    ///    for `probe_side`. `JoinSide::Left` selects the left side as probe; any
    ///    other value is handled as the right side.
    /// 2. Records the batch in the probe side's `input_batches`/`input_rows` metrics.
    /// 3. Adds the batch to the probe side's own hash joiner (`update_internal_state`).
    ///    Batches that arrive later from the other side can then be matched against
    ///    these rows.
    /// 4. Joins the batch against the build side with `join_with_probe_batch`. It then
    ///    advances the probe joiner's `offset` by the batch's row count.
    /// 5. If both sides have a sorted filter expression and an interval graph is
    ///    present, it does three things:
    ///    - computes a prune length for the build side;
    ///    - emits the output that `build_side_determined_results` yields for those rows;
    ///    - prunes those rows from the build side's state.
    /// 6. Merges the outputs of steps 4 and 5 with `combine_two_batches`. It then
    ///    refreshes the `stream_memory_usage` metric and resizes the memory
    ///    reservation to `self.size()`.
    ///
    /// # Returns
    ///
    /// The combined output batch, or `None` if `combine_two_batches` produced no batch.
    ///
    /// # Errors
    ///
    /// Propagates any error from the steps above. This includes a failed
    /// `try_resize` on the memory reservation.
    fn perform_join_for_given_side(
        &mut self,
        probe_batch: RecordBatch,
        probe_side: JoinSide,
    ) -> Result<Option<RecordBatch>> {
        // Borrow the per-side state (disjoint fields of `self`); the build side is
        // always the side opposite to `probe_side`.
        let (
            probe_hash_joiner,
            build_hash_joiner,
            probe_side_sorted_filter_expr,
            build_side_sorted_filter_expr,
            probe_side_metrics,
        ) = if probe_side.eq(&JoinSide::Left) {
            (
                &mut self.left,
                &mut self.right,
                &mut self.left_sorted_filter_expr,
                &mut self.right_sorted_filter_expr,
                &mut self.metrics.left,
            )
        } else {
            (
                &mut self.right,
                &mut self.left,
                &mut self.right_sorted_filter_expr,
                &mut self.left_sorted_filter_expr,
                &mut self.metrics.right,
            )
        };
        // Update the metrics for the stream that was polled:
        probe_side_metrics.input_batches.add(1);
        probe_side_metrics.input_rows.add(probe_batch.num_rows());
        // Store the probe batch in its own side's hash joiner. That joiner acts as the
        // build side when a batch later arrives from the opposite stream.
        probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?;
        // Join the probe batch against the rows currently buffered on the build side:
        let equal_result = join_with_probe_batch(
            build_hash_joiner,
            probe_hash_joiner,
            &self.schema,
            self.join_type,
            self.filter.as_ref(),
            &probe_batch,
            &self.column_indices,
            &self.random_state,
            self.null_equality,
        )?;
        // Increment the offset for the probe hash joiner:
        probe_hash_joiner.offset += probe_batch.num_rows();

        // Pruning only happens when both sides have sorted filter expressions and an
        // interval graph exists. Despite its name, `anti_result` holds whatever
        // `build_side_determined_results` emits for the build-side rows about to be
        // pruned (or `None` when pruning is not possible).
        let anti_result = if let (
            Some(build_side_sorted_filter_expr),
            Some(probe_side_sorted_filter_expr),
            Some(graph),
        ) = (
            build_side_sorted_filter_expr.as_mut(),
            probe_side_sorted_filter_expr.as_mut(),
            self.graph.as_mut(),
        ) {
            // Compute filter-expression intervals for the build buffer and probe batch:
            calculate_filter_expr_intervals(
                &build_hash_joiner.input_buffer,
                build_side_sorted_filter_expr,
                &probe_batch,
                probe_side_sorted_filter_expr,
            )?;
            let prune_length = build_hash_joiner
                .calculate_prune_length_with_probe_batch(
                    build_side_sorted_filter_expr,
                    probe_side_sorted_filter_expr,
                    graph,
                )?;
            let result = build_side_determined_results(
                build_hash_joiner,
                &self.schema,
                prune_length,
                probe_batch.schema(),
                self.join_type,
                &self.column_indices,
            )?;
            // Output for these rows was built above, so they can now be dropped.
            build_hash_joiner.prune_internal_state(prune_length)?;
            result
        } else {
            None
        };

        // Combine results:
        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
        // Refresh memory accounting; a failed `try_resize` is propagated to the caller.
        let capacity = self.size();
        self.metrics.stream_memory_usage.set(capacity);
        self.reservation.lock().try_resize(capacity)?;
        Ok(result)
    }
1729}
1730
/// State machine for a symmetric hash join stream.
///
/// Tracks which input should be polled next and whether one or both inputs have
/// been fully consumed, so the stream knows how to proceed on each poll.
#[derive(Clone, Debug)]
pub enum SHJStreamState {
    /// Poll the right input next.
    PullRight,

    /// Poll the left input next.
    PullLeft,

    /// The right input has been fully consumed.
    RightExhausted,

    /// The left input has been fully consumed.
    LeftExhausted,

    /// Both inputs have been fully consumed.
    ///
    /// `final_result` records whether the final output step has already been
    /// entered (`true`) or is still pending (`false`).
    BothExhausted { final_result: bool },
}
1758
1759#[cfg(test)]
1760mod tests {
1761    use std::collections::HashMap;
1762    use std::sync::{LazyLock, Mutex};
1763
1764    use super::*;
1765    use crate::joins::test_utils::{
1766        build_sides_record_batches, compare_batches, complicated_filter,
1767        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1768        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1769        partitioned_sym_join_with_filter, split_record_batches,
1770    };
1771
1772    use arrow::compute::SortOptions;
1773    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1774    use datafusion_common::ScalarValue;
1775    use datafusion_execution::config::SessionConfig;
1776    use datafusion_expr::Operator;
1777    use datafusion_physical_expr::expressions::{binary, col, lit, Column};
1778    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
1779
1780    use rstest::*;
1781
1782    const TABLE_SIZE: i32 = 30;
1783
1784    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
1785    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)
1786
1787    // Cache for storing tables
1788    static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
1789        LazyLock::new(|| Mutex::new(HashMap::new()));
1790
1791    fn get_or_create_table(
1792        cardinality: (i32, i32),
1793        batch_size: usize,
1794    ) -> Result<TableValue> {
1795        {
1796            let cache = TABLE_CACHE.lock().unwrap();
1797            if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1798                return Ok(table.clone());
1799            }
1800        }
1801
1802        // If not, create the table
1803        let (left_batch, right_batch) =
1804            build_sides_record_batches(TABLE_SIZE, cardinality)?;
1805
1806        let (left_partition, right_partition) = (
1807            split_record_batches(&left_batch, batch_size)?,
1808            split_record_batches(&right_batch, batch_size)?,
1809        );
1810
1811        // Lock the cache again and store the table
1812        let mut cache = TABLE_CACHE.lock().unwrap();
1813
1814        // Store the table in the cache
1815        cache.insert(
1816            (cardinality.0, cardinality.1, batch_size),
1817            (left_partition.clone(), right_partition.clone()),
1818        );
1819
1820        Ok((left_partition, right_partition))
1821    }
1822
1823    pub async fn experiment(
1824        left: Arc<dyn ExecutionPlan>,
1825        right: Arc<dyn ExecutionPlan>,
1826        filter: Option<JoinFilter>,
1827        join_type: JoinType,
1828        on: JoinOn,
1829        task_ctx: Arc<TaskContext>,
1830    ) -> Result<()> {
1831        let first_batches = partitioned_sym_join_with_filter(
1832            Arc::clone(&left),
1833            Arc::clone(&right),
1834            on.clone(),
1835            filter.clone(),
1836            &join_type,
1837            NullEquality::NullEqualsNothing,
1838            Arc::clone(&task_ctx),
1839        )
1840        .await?;
1841        let second_batches = partitioned_hash_join_with_filter(
1842            left,
1843            right,
1844            on,
1845            filter,
1846            &join_type,
1847            NullEquality::NullEqualsNothing,
1848            task_ctx,
1849        )
1850        .await?;
1851        compare_batches(&first_batches, &second_batches);
1852        Ok(())
1853    }
1854
1855    #[rstest]
1856    #[tokio::test(flavor = "multi_thread")]
1857    async fn complex_join_all_one_ascending_numeric(
1858        #[values(
1859            JoinType::Inner,
1860            JoinType::Left,
1861            JoinType::Right,
1862            JoinType::RightSemi,
1863            JoinType::LeftSemi,
1864            JoinType::LeftAnti,
1865            JoinType::LeftMark,
1866            JoinType::RightAnti,
1867            JoinType::RightMark,
1868            JoinType::Full
1869        )]
1870        join_type: JoinType,
1871        #[values(
1872        (4, 5),
1873        (12, 17),
1874        )]
1875        cardinality: (i32, i32),
1876    ) -> Result<()> {
1877        // a + b > c + 10 AND a + b < c + 100
1878        let task_ctx = Arc::new(TaskContext::default());
1879
1880        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
1881
1882        let left_schema = &left_partition[0].schema();
1883        let right_schema = &right_partition[0].schema();
1884
1885        let left_sorted = [PhysicalSortExpr {
1886            expr: binary(
1887                col("la1", left_schema)?,
1888                Operator::Plus,
1889                col("la2", left_schema)?,
1890                left_schema,
1891            )?,
1892            options: SortOptions::default(),
1893        }]
1894        .into();
1895        let right_sorted = [PhysicalSortExpr {
1896            expr: col("ra1", right_schema)?,
1897            options: SortOptions::default(),
1898        }]
1899        .into();
1900        let (left, right) = create_memory_table(
1901            left_partition,
1902            right_partition,
1903            vec![left_sorted],
1904            vec![right_sorted],
1905        )?;
1906
1907        let on = vec![(
1908            binary(
1909                col("lc1", left_schema)?,
1910                Operator::Plus,
1911                lit(ScalarValue::Int32(Some(1))),
1912                left_schema,
1913            )?,
1914            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1915        )];
1916
1917        let intermediate_schema = Schema::new(vec![
1918            Field::new("0", DataType::Int32, true),
1919            Field::new("1", DataType::Int32, true),
1920            Field::new("2", DataType::Int32, true),
1921        ]);
1922        let filter_expr = complicated_filter(&intermediate_schema)?;
1923        let column_indices = vec![
1924            ColumnIndex {
1925                index: left_schema.index_of("la1")?,
1926                side: JoinSide::Left,
1927            },
1928            ColumnIndex {
1929                index: left_schema.index_of("la2")?,
1930                side: JoinSide::Left,
1931            },
1932            ColumnIndex {
1933                index: right_schema.index_of("ra1")?,
1934                side: JoinSide::Right,
1935            },
1936        ];
1937        let filter =
1938            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1939
1940        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1941        Ok(())
1942    }
1943
1944    #[rstest]
1945    #[tokio::test(flavor = "multi_thread")]
1946    async fn join_all_one_ascending_numeric(
1947        #[values(
1948            JoinType::Inner,
1949            JoinType::Left,
1950            JoinType::Right,
1951            JoinType::RightSemi,
1952            JoinType::LeftSemi,
1953            JoinType::LeftAnti,
1954            JoinType::LeftMark,
1955            JoinType::RightAnti,
1956            JoinType::RightMark,
1957            JoinType::Full
1958        )]
1959        join_type: JoinType,
1960        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1961    ) -> Result<()> {
1962        let task_ctx = Arc::new(TaskContext::default());
1963        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1964
1965        let left_schema = &left_partition[0].schema();
1966        let right_schema = &right_partition[0].schema();
1967
1968        let left_sorted = [PhysicalSortExpr {
1969            expr: col("la1", left_schema)?,
1970            options: SortOptions::default(),
1971        }]
1972        .into();
1973        let right_sorted = [PhysicalSortExpr {
1974            expr: col("ra1", right_schema)?,
1975            options: SortOptions::default(),
1976        }]
1977        .into();
1978        let (left, right) = create_memory_table(
1979            left_partition,
1980            right_partition,
1981            vec![left_sorted],
1982            vec![right_sorted],
1983        )?;
1984
1985        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
1986
1987        let intermediate_schema = Schema::new(vec![
1988            Field::new("left", DataType::Int32, true),
1989            Field::new("right", DataType::Int32, true),
1990        ]);
1991        let filter_expr = join_expr_tests_fixture_i32(
1992            case_expr,
1993            col("left", &intermediate_schema)?,
1994            col("right", &intermediate_schema)?,
1995        );
1996        let column_indices = vec![
1997            ColumnIndex {
1998                index: 0,
1999                side: JoinSide::Left,
2000            },
2001            ColumnIndex {
2002                index: 0,
2003                side: JoinSide::Right,
2004            },
2005        ];
2006        let filter =
2007            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2008
2009        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2010        Ok(())
2011    }
2012
2013    #[rstest]
2014    #[tokio::test(flavor = "multi_thread")]
2015    async fn join_without_sort_information(
2016        #[values(
2017            JoinType::Inner,
2018            JoinType::Left,
2019            JoinType::Right,
2020            JoinType::RightSemi,
2021            JoinType::LeftSemi,
2022            JoinType::LeftAnti,
2023            JoinType::LeftMark,
2024            JoinType::RightAnti,
2025            JoinType::RightMark,
2026            JoinType::Full
2027        )]
2028        join_type: JoinType,
2029        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2030    ) -> Result<()> {
2031        let task_ctx = Arc::new(TaskContext::default());
2032        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
2033
2034        let left_schema = &left_partition[0].schema();
2035        let right_schema = &right_partition[0].schema();
2036        let (left, right) =
2037            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2038
2039        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2040
2041        let intermediate_schema = Schema::new(vec![
2042            Field::new("left", DataType::Int32, true),
2043            Field::new("right", DataType::Int32, true),
2044        ]);
2045        let filter_expr = join_expr_tests_fixture_i32(
2046            case_expr,
2047            col("left", &intermediate_schema)?,
2048            col("right", &intermediate_schema)?,
2049        );
2050        let column_indices = vec![
2051            ColumnIndex {
2052                index: 5,
2053                side: JoinSide::Left,
2054            },
2055            ColumnIndex {
2056                index: 5,
2057                side: JoinSide::Right,
2058            },
2059        ];
2060        let filter =
2061            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2062
2063        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2064        Ok(())
2065    }
2066
2067    #[rstest]
2068    #[tokio::test(flavor = "multi_thread")]
2069    async fn join_without_filter(
2070        #[values(
2071            JoinType::Inner,
2072            JoinType::Left,
2073            JoinType::Right,
2074            JoinType::RightSemi,
2075            JoinType::LeftSemi,
2076            JoinType::LeftAnti,
2077            JoinType::LeftMark,
2078            JoinType::RightAnti,
2079            JoinType::RightMark,
2080            JoinType::Full
2081        )]
2082        join_type: JoinType,
2083    ) -> Result<()> {
2084        let task_ctx = Arc::new(TaskContext::default());
2085        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2086        let left_schema = &left_partition[0].schema();
2087        let right_schema = &right_partition[0].schema();
2088        let (left, right) =
2089            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2090
2091        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2092        experiment(left, right, None, join_type, on, task_ctx).await?;
2093        Ok(())
2094    }
2095
2096    #[rstest]
2097    #[tokio::test(flavor = "multi_thread")]
2098    async fn join_all_one_descending_numeric_particular(
2099        #[values(
2100            JoinType::Inner,
2101            JoinType::Left,
2102            JoinType::Right,
2103            JoinType::RightSemi,
2104            JoinType::LeftSemi,
2105            JoinType::LeftAnti,
2106            JoinType::LeftMark,
2107            JoinType::RightAnti,
2108            JoinType::RightMark,
2109            JoinType::Full
2110        )]
2111        join_type: JoinType,
2112        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2113    ) -> Result<()> {
2114        let task_ctx = Arc::new(TaskContext::default());
2115        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2116        let left_schema = &left_partition[0].schema();
2117        let right_schema = &right_partition[0].schema();
2118        let left_sorted = [PhysicalSortExpr {
2119            expr: col("la1_des", left_schema)?,
2120            options: SortOptions {
2121                descending: true,
2122                nulls_first: true,
2123            },
2124        }]
2125        .into();
2126        let right_sorted = [PhysicalSortExpr {
2127            expr: col("ra1_des", right_schema)?,
2128            options: SortOptions {
2129                descending: true,
2130                nulls_first: true,
2131            },
2132        }]
2133        .into();
2134        let (left, right) = create_memory_table(
2135            left_partition,
2136            right_partition,
2137            vec![left_sorted],
2138            vec![right_sorted],
2139        )?;
2140
2141        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2142
2143        let intermediate_schema = Schema::new(vec![
2144            Field::new("left", DataType::Int32, true),
2145            Field::new("right", DataType::Int32, true),
2146        ]);
2147        let filter_expr = join_expr_tests_fixture_i32(
2148            case_expr,
2149            col("left", &intermediate_schema)?,
2150            col("right", &intermediate_schema)?,
2151        );
2152        let column_indices = vec![
2153            ColumnIndex {
2154                index: 5,
2155                side: JoinSide::Left,
2156            },
2157            ColumnIndex {
2158                index: 5,
2159                side: JoinSide::Right,
2160            },
2161        ];
2162        let filter =
2163            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2164
2165        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2166        Ok(())
2167    }
2168
2169    #[tokio::test(flavor = "multi_thread")]
2170    async fn build_null_columns_first() -> Result<()> {
2171        let join_type = JoinType::Full;
2172        let case_expr = 1;
2173        let session_config = SessionConfig::new().with_repartition_joins(false);
2174        let task_ctx = TaskContext::default().with_session_config(session_config);
2175        let task_ctx = Arc::new(task_ctx);
2176        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2177        let left_schema = &left_partition[0].schema();
2178        let right_schema = &right_partition[0].schema();
2179        let left_sorted = [PhysicalSortExpr {
2180            expr: col("l_asc_null_first", left_schema)?,
2181            options: SortOptions {
2182                descending: false,
2183                nulls_first: true,
2184            },
2185        }]
2186        .into();
2187        let right_sorted = [PhysicalSortExpr {
2188            expr: col("r_asc_null_first", right_schema)?,
2189            options: SortOptions {
2190                descending: false,
2191                nulls_first: true,
2192            },
2193        }]
2194        .into();
2195        let (left, right) = create_memory_table(
2196            left_partition,
2197            right_partition,
2198            vec![left_sorted],
2199            vec![right_sorted],
2200        )?;
2201
2202        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2203
2204        let intermediate_schema = Schema::new(vec![
2205            Field::new("left", DataType::Int32, true),
2206            Field::new("right", DataType::Int32, true),
2207        ]);
2208        let filter_expr = join_expr_tests_fixture_i32(
2209            case_expr,
2210            col("left", &intermediate_schema)?,
2211            col("right", &intermediate_schema)?,
2212        );
2213        let column_indices = vec![
2214            ColumnIndex {
2215                index: 6,
2216                side: JoinSide::Left,
2217            },
2218            ColumnIndex {
2219                index: 6,
2220                side: JoinSide::Right,
2221            },
2222        ];
2223        let filter =
2224            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2225        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2226        Ok(())
2227    }
2228
2229    #[tokio::test(flavor = "multi_thread")]
2230    async fn build_null_columns_last() -> Result<()> {
2231        let join_type = JoinType::Full;
2232        let case_expr = 1;
2233        let session_config = SessionConfig::new().with_repartition_joins(false);
2234        let task_ctx = TaskContext::default().with_session_config(session_config);
2235        let task_ctx = Arc::new(task_ctx);
2236        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2237
2238        let left_schema = &left_partition[0].schema();
2239        let right_schema = &right_partition[0].schema();
2240        let left_sorted = [PhysicalSortExpr {
2241            expr: col("l_asc_null_last", left_schema)?,
2242            options: SortOptions {
2243                descending: false,
2244                nulls_first: false,
2245            },
2246        }]
2247        .into();
2248        let right_sorted = [PhysicalSortExpr {
2249            expr: col("r_asc_null_last", right_schema)?,
2250            options: SortOptions {
2251                descending: false,
2252                nulls_first: false,
2253            },
2254        }]
2255        .into();
2256        let (left, right) = create_memory_table(
2257            left_partition,
2258            right_partition,
2259            vec![left_sorted],
2260            vec![right_sorted],
2261        )?;
2262
2263        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2264
2265        let intermediate_schema = Schema::new(vec![
2266            Field::new("left", DataType::Int32, true),
2267            Field::new("right", DataType::Int32, true),
2268        ]);
2269        let filter_expr = join_expr_tests_fixture_i32(
2270            case_expr,
2271            col("left", &intermediate_schema)?,
2272            col("right", &intermediate_schema)?,
2273        );
2274        let column_indices = vec![
2275            ColumnIndex {
2276                index: 7,
2277                side: JoinSide::Left,
2278            },
2279            ColumnIndex {
2280                index: 7,
2281                side: JoinSide::Right,
2282            },
2283        ];
2284        let filter =
2285            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2286
2287        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2288        Ok(())
2289    }
2290
2291    #[tokio::test(flavor = "multi_thread")]
2292    async fn build_null_columns_first_descending() -> Result<()> {
2293        let join_type = JoinType::Full;
2294        let cardinality = (10, 11);
2295        let case_expr = 1;
2296        let session_config = SessionConfig::new().with_repartition_joins(false);
2297        let task_ctx = TaskContext::default().with_session_config(session_config);
2298        let task_ctx = Arc::new(task_ctx);
2299        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2300
2301        let left_schema = &left_partition[0].schema();
2302        let right_schema = &right_partition[0].schema();
2303        let left_sorted = [PhysicalSortExpr {
2304            expr: col("l_desc_null_first", left_schema)?,
2305            options: SortOptions {
2306                descending: true,
2307                nulls_first: true,
2308            },
2309        }]
2310        .into();
2311        let right_sorted = [PhysicalSortExpr {
2312            expr: col("r_desc_null_first", right_schema)?,
2313            options: SortOptions {
2314                descending: true,
2315                nulls_first: true,
2316            },
2317        }]
2318        .into();
2319        let (left, right) = create_memory_table(
2320            left_partition,
2321            right_partition,
2322            vec![left_sorted],
2323            vec![right_sorted],
2324        )?;
2325
2326        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2327
2328        let intermediate_schema = Schema::new(vec![
2329            Field::new("left", DataType::Int32, true),
2330            Field::new("right", DataType::Int32, true),
2331        ]);
2332        let filter_expr = join_expr_tests_fixture_i32(
2333            case_expr,
2334            col("left", &intermediate_schema)?,
2335            col("right", &intermediate_schema)?,
2336        );
2337        let column_indices = vec![
2338            ColumnIndex {
2339                index: 8,
2340                side: JoinSide::Left,
2341            },
2342            ColumnIndex {
2343                index: 8,
2344                side: JoinSide::Right,
2345            },
2346        ];
2347        let filter =
2348            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2349
2350        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2351        Ok(())
2352    }
2353
2354    #[tokio::test(flavor = "multi_thread")]
2355    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2356        let cardinality = (3, 4);
2357        let join_type = JoinType::Full;
2358
2359        // a + b > c + 10 AND a + b < c + 100
2360        let session_config = SessionConfig::new().with_repartition_joins(false);
2361        let task_ctx = TaskContext::default().with_session_config(session_config);
2362        let task_ctx = Arc::new(task_ctx);
2363        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2364
2365        let left_schema = &left_partition[0].schema();
2366        let right_schema = &right_partition[0].schema();
2367        let left_sorted = [PhysicalSortExpr {
2368            expr: col("la1", left_schema)?,
2369            options: SortOptions::default(),
2370        }]
2371        .into();
2372        let right_sorted = [PhysicalSortExpr {
2373            expr: col("ra1", right_schema)?,
2374            options: SortOptions::default(),
2375        }]
2376        .into();
2377        let (left, right) = create_memory_table(
2378            left_partition,
2379            right_partition,
2380            vec![left_sorted],
2381            vec![right_sorted],
2382        )?;
2383
2384        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2385
2386        let intermediate_schema = Schema::new(vec![
2387            Field::new("0", DataType::Int32, true),
2388            Field::new("1", DataType::Int32, true),
2389            Field::new("2", DataType::Int32, true),
2390        ]);
2391        let filter_expr = complicated_filter(&intermediate_schema)?;
2392        let column_indices = vec![
2393            ColumnIndex {
2394                index: 0,
2395                side: JoinSide::Left,
2396            },
2397            ColumnIndex {
2398                index: 4,
2399                side: JoinSide::Left,
2400            },
2401            ColumnIndex {
2402                index: 0,
2403                side: JoinSide::Right,
2404            },
2405        ];
2406        let filter =
2407            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2408
2409        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2410        Ok(())
2411    }
2412
2413    #[tokio::test(flavor = "multi_thread")]
2414    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2415        let cardinality = (3, 4);
2416        let join_type = JoinType::Full;
2417
2418        // a + b > c + 10 AND a + b < c + 100
2419        let config = SessionConfig::new().with_repartition_joins(false);
2420        // let session_ctx = SessionContext::with_config(config);
2421        // let task_ctx = session_ctx.task_ctx();
2422        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2423        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2424        let left_schema = &left_partition[0].schema();
2425        let right_schema = &right_partition[0].schema();
2426        let left_sorted = vec![
2427            [PhysicalSortExpr {
2428                expr: col("la1", left_schema)?,
2429                options: SortOptions::default(),
2430            }]
2431            .into(),
2432            [PhysicalSortExpr {
2433                expr: col("la2", left_schema)?,
2434                options: SortOptions::default(),
2435            }]
2436            .into(),
2437        ];
2438
2439        let right_sorted = [PhysicalSortExpr {
2440            expr: col("ra1", right_schema)?,
2441            options: SortOptions::default(),
2442        }]
2443        .into();
2444
2445        let (left, right) = create_memory_table(
2446            left_partition,
2447            right_partition,
2448            left_sorted,
2449            vec![right_sorted],
2450        )?;
2451
2452        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2453
2454        let intermediate_schema = Schema::new(vec![
2455            Field::new("0", DataType::Int32, true),
2456            Field::new("1", DataType::Int32, true),
2457            Field::new("2", DataType::Int32, true),
2458        ]);
2459        let filter_expr = complicated_filter(&intermediate_schema)?;
2460        let column_indices = vec![
2461            ColumnIndex {
2462                index: 0,
2463                side: JoinSide::Left,
2464            },
2465            ColumnIndex {
2466                index: 4,
2467                side: JoinSide::Left,
2468            },
2469            ColumnIndex {
2470                index: 0,
2471                side: JoinSide::Right,
2472            },
2473        ];
2474        let filter =
2475            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2476
2477        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2478        Ok(())
2479    }
2480
2481    #[rstest]
2482    #[tokio::test(flavor = "multi_thread")]
2483    async fn testing_with_temporal_columns(
2484        #[values(
2485            JoinType::Inner,
2486            JoinType::Left,
2487            JoinType::Right,
2488            JoinType::RightSemi,
2489            JoinType::LeftSemi,
2490            JoinType::LeftAnti,
2491            JoinType::LeftMark,
2492            JoinType::RightAnti,
2493            JoinType::RightMark,
2494            JoinType::Full
2495        )]
2496        join_type: JoinType,
2497        #[values(
2498            (4, 5),
2499            (12, 17),
2500        )]
2501        cardinality: (i32, i32),
2502        #[values(0, 1, 2)] case_expr: usize,
2503    ) -> Result<()> {
2504        let session_config = SessionConfig::new().with_repartition_joins(false);
2505        let task_ctx = TaskContext::default().with_session_config(session_config);
2506        let task_ctx = Arc::new(task_ctx);
2507        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2508
2509        let left_schema = &left_partition[0].schema();
2510        let right_schema = &right_partition[0].schema();
2511        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2512        let left_sorted = [PhysicalSortExpr {
2513            expr: col("lt1", left_schema)?,
2514            options: SortOptions {
2515                descending: false,
2516                nulls_first: true,
2517            },
2518        }]
2519        .into();
2520        let right_sorted = [PhysicalSortExpr {
2521            expr: col("rt1", right_schema)?,
2522            options: SortOptions {
2523                descending: false,
2524                nulls_first: true,
2525            },
2526        }]
2527        .into();
2528        let (left, right) = create_memory_table(
2529            left_partition,
2530            right_partition,
2531            vec![left_sorted],
2532            vec![right_sorted],
2533        )?;
2534        let intermediate_schema = Schema::new(vec![
2535            Field::new(
2536                "left",
2537                DataType::Timestamp(TimeUnit::Millisecond, None),
2538                false,
2539            ),
2540            Field::new(
2541                "right",
2542                DataType::Timestamp(TimeUnit::Millisecond, None),
2543                false,
2544            ),
2545        ]);
2546        let filter_expr = join_expr_tests_fixture_temporal(
2547            case_expr,
2548            col("left", &intermediate_schema)?,
2549            col("right", &intermediate_schema)?,
2550            &intermediate_schema,
2551        )?;
2552        let column_indices = vec![
2553            ColumnIndex {
2554                index: 3,
2555                side: JoinSide::Left,
2556            },
2557            ColumnIndex {
2558                index: 3,
2559                side: JoinSide::Right,
2560            },
2561        ];
2562        let filter =
2563            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2564        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2565        Ok(())
2566    }
2567
2568    #[rstest]
2569    #[tokio::test(flavor = "multi_thread")]
2570    async fn test_with_interval_columns(
2571        #[values(
2572            JoinType::Inner,
2573            JoinType::Left,
2574            JoinType::Right,
2575            JoinType::RightSemi,
2576            JoinType::LeftSemi,
2577            JoinType::LeftAnti,
2578            JoinType::LeftMark,
2579            JoinType::RightAnti,
2580            JoinType::RightMark,
2581            JoinType::Full
2582        )]
2583        join_type: JoinType,
2584        #[values(
2585            (4, 5),
2586            (12, 17),
2587        )]
2588        cardinality: (i32, i32),
2589    ) -> Result<()> {
2590        let session_config = SessionConfig::new().with_repartition_joins(false);
2591        let task_ctx = TaskContext::default().with_session_config(session_config);
2592        let task_ctx = Arc::new(task_ctx);
2593        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2594
2595        let left_schema = &left_partition[0].schema();
2596        let right_schema = &right_partition[0].schema();
2597        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2598        let left_sorted = [PhysicalSortExpr {
2599            expr: col("li1", left_schema)?,
2600            options: SortOptions {
2601                descending: false,
2602                nulls_first: true,
2603            },
2604        }]
2605        .into();
2606        let right_sorted = [PhysicalSortExpr {
2607            expr: col("ri1", right_schema)?,
2608            options: SortOptions {
2609                descending: false,
2610                nulls_first: true,
2611            },
2612        }]
2613        .into();
2614        let (left, right) = create_memory_table(
2615            left_partition,
2616            right_partition,
2617            vec![left_sorted],
2618            vec![right_sorted],
2619        )?;
2620        let intermediate_schema = Schema::new(vec![
2621            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2622            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2623        ]);
2624        let filter_expr = join_expr_tests_fixture_temporal(
2625            0,
2626            col("left", &intermediate_schema)?,
2627            col("right", &intermediate_schema)?,
2628            &intermediate_schema,
2629        )?;
2630        let column_indices = vec![
2631            ColumnIndex {
2632                index: 9,
2633                side: JoinSide::Left,
2634            },
2635            ColumnIndex {
2636                index: 9,
2637                side: JoinSide::Right,
2638            },
2639        ];
2640        let filter =
2641            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2642        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2643
2644        Ok(())
2645    }
2646
2647    #[rstest]
2648    #[tokio::test(flavor = "multi_thread")]
2649    async fn testing_ascending_float_pruning(
2650        #[values(
2651            JoinType::Inner,
2652            JoinType::Left,
2653            JoinType::Right,
2654            JoinType::RightSemi,
2655            JoinType::LeftSemi,
2656            JoinType::LeftAnti,
2657            JoinType::LeftMark,
2658            JoinType::RightAnti,
2659            JoinType::RightMark,
2660            JoinType::Full
2661        )]
2662        join_type: JoinType,
2663        #[values(
2664            (4, 5),
2665            (12, 17),
2666        )]
2667        cardinality: (i32, i32),
2668        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2669    ) -> Result<()> {
2670        let session_config = SessionConfig::new().with_repartition_joins(false);
2671        let task_ctx = TaskContext::default().with_session_config(session_config);
2672        let task_ctx = Arc::new(task_ctx);
2673        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2674
2675        let left_schema = &left_partition[0].schema();
2676        let right_schema = &right_partition[0].schema();
2677        let left_sorted = [PhysicalSortExpr {
2678            expr: col("l_float", left_schema)?,
2679            options: SortOptions::default(),
2680        }]
2681        .into();
2682        let right_sorted = [PhysicalSortExpr {
2683            expr: col("r_float", right_schema)?,
2684            options: SortOptions::default(),
2685        }]
2686        .into();
2687        let (left, right) = create_memory_table(
2688            left_partition,
2689            right_partition,
2690            vec![left_sorted],
2691            vec![right_sorted],
2692        )?;
2693
2694        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2695
2696        let intermediate_schema = Schema::new(vec![
2697            Field::new("left", DataType::Float64, true),
2698            Field::new("right", DataType::Float64, true),
2699        ]);
2700        let filter_expr = join_expr_tests_fixture_f64(
2701            case_expr,
2702            col("left", &intermediate_schema)?,
2703            col("right", &intermediate_schema)?,
2704        );
2705        let column_indices = vec![
2706            ColumnIndex {
2707                index: 10, // l_float
2708                side: JoinSide::Left,
2709            },
2710            ColumnIndex {
2711                index: 10, // r_float
2712                side: JoinSide::Right,
2713            },
2714        ];
2715        let filter =
2716            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2717
2718        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2719        Ok(())
2720    }
2721}