datafusion_physical_plan/joins/symmetric_hash_join.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file implements the symmetric hash join algorithm with range-based
//! data pruning to join two (potentially infinite) streams.
//!
//! A [`SymmetricHashJoinExec`] plan takes two child plans (with appropriate
//! output orderings) and produces the join output according to the given join
//! type and other options.
//!
//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations
//! for both its children.
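//!
//! As a rough construction sketch (not a runnable doctest; `left`, `right`,
//! `on`, `filter` and the sort expressions stand for values built elsewhere),
//! the operator is created via [`SymmetricHashJoinExec::try_new`]:
//!
//! ```text
//! let join = SymmetricHashJoinExec::try_new(
//!     left,                                 // Arc<dyn ExecutionPlan>
//!     right,                                // Arc<dyn ExecutionPlan>
//!     on,                                   // non-empty equi-join key pairs
//!     filter,                               // Option<JoinFilter> with range conditions
//!     &JoinType::Inner,
//!     NullEquality::NullEqualsNothing,
//!     left_sort_exprs,                      // Option<LexOrdering>
//!     right_sort_exprs,                     // Option<LexOrdering>
//!     StreamJoinPartitionMode::Partitioned,
//! )?;
//! ```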

use std::any::Any;
use std::fmt::{self, Debug};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;

use crate::common::SharedMemoryReservation;
use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
use crate::joins::stream_join_utils::{
    calculate_filter_expr_intervals, combine_two_batches,
    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
};
use crate::joins::utils::{
    apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
    check_join_is_valid, equal_rows_arr, symmetric_join_output_partitioning, update_hash,
    BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
    JoinOnRef, NoopBatchTransformer, StatefulStreamResult,
};
use crate::projection::{
    join_allows_pushdown, join_table_borders, new_join_children,
    physical_to_column_exprs, update_join_filter, update_join_on, ProjectionExec,
};
use crate::{
    joins::StreamJoinPartitionMode,
    metrics::{ExecutionPlanMetricsSet, MetricsSet},
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};

use arrow::array::{
    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
    UInt64Array,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::bisect;
use datafusion_common::{
    internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, Result,
};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};

use ahash::RandomState;
use futures::{ready, Stream, StreamExt};
use parking_lot::Mutex;

const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;

/// In a symmetric hash join with range conditions, both streams are hashed on the
/// join key and the resulting hash tables are used to join the streams.
/// The join is considered symmetric because hash tables are built on the join keys from both
/// streams, and the matching of rows is based on the values of the join keys in both streams.
/// This type of join is efficient in a streaming context as it allows for fast lookups in the hash
/// table, rather than having to scan through one or both of the streams to find matching rows. It also
/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions),
/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
/// data without any memory issues.
///
/// For each input stream, create a hash table.
///   - For each new [RecordBatch] on the build side, hash it and insert it into the input's hash table; update offsets.
///   - Test if the input is equal to a predefined set of other inputs.
///   - If so, record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
///   - Try to prune the other (probe) side with the new [RecordBatch].
///   - If the join type indicates that the unmatched row results must be produced (LEFT, FULL etc.),
///     output the [RecordBatch] when a pruning happens or at the end of the data.
///
///
/// ``` text
///                        +-------------------------+
///                        |                         |
///   left stream ---------|  Left OneSideHashJoiner |---+
///                        |                         |   |
///                        +-------------------------+   |
///                                                      |
///                                                      |--------- Joined output
///                                                      |
///                        +-------------------------+   |
///                        |                         |   |
///  right stream ---------| Right OneSideHashJoiner |---+
///                        |                         |
///                        +-------------------------+
///
/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic
/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
///
///
///               PROBE SIDE          BUILD SIDE
///                 BUFFER              BUFFER
///             +-------------+     +------------+
///             |             |     |            |    Unjoinable
///             |             |     |            |    Range
///             |             |     |            |
///             |             |  |---------------------------------
///             |             |  |  |            |
///             |             |  |  |            |
///             |             | /   |            |
///             |             | |   |            |
///             |             | |   |            |
///             |             | |   |            |
///             |             | |   |            |
///             |             | |   |            |    Joinable
///             |             |/    |            |    Range
///             |             ||    |            |
///             |+-----------+||    |            |
///             || Record    ||     |            |
///             || Batch     ||     |            |
///             |+-----------+||    |            |
///             +-------------+\    +------------+
///                             |
///                             \
///                              |---------------------------------
///
///  This happens when range conditions are provided on sorted columns. E.g.
///
///        SELECT * FROM left_table, right_table
///        ON
///          left_key = right_key AND
///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
///
/// or
///       SELECT * FROM left_table, right_table
///        ON
///          left_key = right_key AND
///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
///
/// In the second scenario, when new data arrives at the probe side, the conditions can be used to
/// determine a specific threshold for discarding rows from the inner buffer. For example, suppose the sort
/// order of the two columns ("left_sorted" and "right_sorted") is ascending (it can be different in other
/// scenarios), the join condition is "left_sorted > right_sorted - 3", and the latest value on the right
/// input is 1234. Since "right_sorted" is ascending, future right rows can only raise the threshold
/// "right_sorted - 3", so the left side buffer must only keep rows where "left_sorted > 1234 - 3 = 1231".
/// This makes 1231 the smallest joinable value in 'left_sorted'; any rows at or below it can be dropped
/// from the inner buffer.
/// ```
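///
/// As a concrete sketch of the pruning arithmetic above (hypothetical values,
/// following the second SQL example rather than the actual solver code):
///
/// ```text
/// Filter:              left_sorted > right_sorted - 3
/// Right side so far:   right_sorted in [1234, +inf)  (ascending, latest = 1234)
/// Implied left range:  left_sorted  in (1231, +inf)  since 1234 - 3 = 1231
///
/// Prune step: binary-search the buffered, ascending left_sorted column for
/// 1231; rows at or before that position can never satisfy the filter against
/// any future right-side row, so they can be discarded safely.
/// ```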
#[derive(Debug, Clone)]
pub struct SymmetricHashJoinExec {
    /// Left side stream
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right side stream
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Set of common columns used to join on
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Filters applied when finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// Shares the `RandomState` for the hashing algorithm
    random_state: RandomState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Defines the null equality for the join.
    pub(crate) null_equality: NullEquality,
    /// Left side sort expression(s)
    pub(crate) left_sort_exprs: Option<LexOrdering>,
    /// Right side sort expression(s)
    pub(crate) right_sort_exprs: Option<LexOrdering>,
    /// Partition mode
    mode: StreamJoinPartitionMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl SymmetricHashJoinExec {
    /// Tries to create a new [SymmetricHashJoinExec].
    /// # Error
    /// This function errors when:
    /// - It is not possible to join the left and right sides on keys `on`, or
    /// - It fails to construct `SortedFilterExpr`s, or
    /// - It fails to create the [ExprIntervalGraph].
    #[allow(clippy::too_many_arguments)]
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        null_equality: NullEquality,
        left_sort_exprs: Option<LexOrdering>,
        right_sort_exprs: Option<LexOrdering>,
        mode: StreamJoinPartitionMode,
    ) -> Result<Self> {
        let left_schema = left.schema();
        let right_schema = right.schema();

        // Error out if no "on" constraints are given:
        if on.is_empty() {
            return plan_err!(
                "On constraints in SymmetricHashJoinExec should be non-empty"
            );
        }

        // Check if the join is valid with the given on constraints:
        check_join_is_valid(&left_schema, &right_schema, &on)?;

        // Build the join schema from the left and right schemas:
        let (schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, join_type);

        // Initialize the random state for the join operation:
        let random_state = RandomState::with_seeds(0, 0, 0, 0);
        let schema = Arc::new(schema);
        let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?;
        Ok(SymmetricHashJoinExec {
            left,
            right,
            on,
            filter,
            join_type: *join_type,
            random_state,
            metrics: ExecutionPlanMetricsSet::new(),
            column_indices,
            null_equality,
            left_sort_exprs,
            right_sort_exprs,
            mode,
            cache,
        })
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        join_type: JoinType,
        join_on: JoinOnRef,
    ) -> Result<PlanProperties> {
        // Calculate equivalence properties:
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            schema,
            &[false, false],
            // Has alternating probe side
            None,
            join_on,
        )?;

        let output_partitioning =
            symmetric_join_output_partitioning(left, right, &join_type)?;

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children([left, right]),
            boundedness_from_children([left, right]),
        ))
    }

    /// Left stream
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// Right stream
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// Set of common columns used to join on
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }

    /// Filters applied before join output
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }

    /// How the join is performed
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// Get null_equality
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }

    /// Get partition mode
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
        self.mode
    }

    /// Get left_sort_exprs
    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
        self.left_sort_exprs.as_ref()
    }

    /// Get right_sort_exprs
    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
        self.right_sort_exprs.as_ref()
    }

    /// Check if order information covers every column in the filter expression.
    pub fn check_if_order_information_available(&self) -> Result<bool> {
        if let Some(filter) = self.filter() {
            let left = self.left();
            if let Some(left_ordering) = left.output_ordering() {
                let right = self.right();
                if let Some(right_ordering) = right.output_ordering() {
                    let left_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Left,
                        filter,
                        &left.schema(),
                        &left_ordering[0],
                    )?
                    .is_some();
                    let right_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Right,
                        filter,
                        &right.schema(),
                        &right_ordering[0],
                    )?
                    .is_some();
                    return Ok(left_convertible && right_convertible);
                }
            }
        }
        Ok(false)
    }
}

impl DisplayAs for SymmetricHashJoinExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let display_filter = self.filter.as_ref().map_or_else(
                    || "".to_string(),
                    |f| format!(", filter={}", f.expression()),
                );
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| format!("({c1}, {c2})"))
                    .collect::<Vec<String>>()
                    .join(", ");
                write!(
                    f,
                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
                    self.mode, self.join_type, on, display_filter
                )
            }
            DisplayFormatType::TreeRender => {
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| {
                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
                    })
                    .collect::<Vec<String>>()
                    .join(", ");

                writeln!(f, "mode={:?}", self.mode)?;
                if *self.join_type() != JoinType::Inner {
                    writeln!(f, "join_type={:?}", self.join_type)?;
                }
                writeln!(f, "on={on}")
            }
        }
    }
}

impl ExecutionPlan for SymmetricHashJoinExec {
    fn name(&self) -> &'static str {
        "SymmetricHashJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        match self.mode {
            StreamJoinPartitionMode::Partitioned => {
                let (left_expr, right_expr) = self
                    .on
                    .iter()
                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
                    .unzip();
                vec![
                    Distribution::HashPartitioned(left_expr),
                    Distribution::HashPartitioned(right_expr),
                ]
            }
            StreamJoinPartitionMode::SinglePartition => {
                vec![Distribution::SinglePartition, Distribution::SinglePartition]
            }
        }
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![
            self.left_sort_exprs
                .as_ref()
                .map(|e| OrderingRequirements::from(e.clone())),
            self.right_sort_exprs
                .as_ref()
                .map(|e| OrderingRequirements::from(e.clone())),
        ]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(SymmetricHashJoinExec::try_new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.on.clone(),
            self.filter.clone(),
            &self.join_type,
            self.null_equality,
            self.left_sort_exprs.clone(),
            self.right_sort_exprs.clone(),
            self.mode,
        )?))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        // TODO stats: it is not possible in general to know the output size of joins
        Ok(Statistics::new_unknown(&self.schema()))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let left_partitions = self.left.output_partitioning().partition_count();
        let right_partitions = self.right.output_partitioning().partition_count();
        if left_partitions != right_partitions {
            return internal_err!(
                "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
                 consider using RepartitionExec"
            );
        }
        // If sort expressions for both children and a join filter are all
        // present, then calculate sorted filter expressions for both sides
        // and build an expression graph:
        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
            self.left_sort_exprs(),
            self.right_sort_exprs(),
            &self.filter,
        ) {
            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
                let (left, right, graph) = prepare_sorted_exprs(
                    filter,
                    &self.left,
                    &self.right,
                    left_sort_exprs,
                    right_sort_exprs,
                )?;
                (Some(left), Some(right), Some(graph))
            }
            // Otherwise, return None for all three values:
            _ => (None, None, None),
        };

        let (on_left, on_right) = self.on.iter().cloned().unzip();

        let left_side_joiner =
            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
        let right_side_joiner =
            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());

        let left_stream = self.left.execute(partition, Arc::clone(&context))?;

        let right_stream = self.right.execute(partition, Arc::clone(&context))?;

        let batch_size = context.session_config().batch_size();
        let enforce_batch_size_in_joins =
            context.session_config().enforce_batch_size_in_joins();

        let reservation = Arc::new(Mutex::new(
            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
                .register(context.memory_pool()),
        ));
        if let Some(g) = graph.as_ref() {
            reservation.lock().try_grow(g.size())?;
        }

        if enforce_batch_size_in_joins {
            Ok(Box::pin(SymmetricHashJoinStream {
                left_stream,
                right_stream,
                schema: self.schema(),
                filter: self.filter.clone(),
                join_type: self.join_type,
                random_state: self.random_state.clone(),
                left: left_side_joiner,
                right: right_side_joiner,
                column_indices: self.column_indices.clone(),
                metrics: StreamJoinMetrics::new(partition, &self.metrics),
                graph,
                left_sorted_filter_expr,
                right_sorted_filter_expr,
                null_equality: self.null_equality,
                state: SHJStreamState::PullRight,
                reservation,
                batch_transformer: BatchSplitter::new(batch_size),
            }))
        } else {
            Ok(Box::pin(SymmetricHashJoinStream {
                left_stream,
                right_stream,
                schema: self.schema(),
                filter: self.filter.clone(),
                join_type: self.join_type,
                random_state: self.random_state.clone(),
                left: left_side_joiner,
                right: right_side_joiner,
                column_indices: self.column_indices.clone(),
                metrics: StreamJoinMetrics::new(partition, &self.metrics),
                graph,
                left_sorted_filter_expr,
                right_sorted_filter_expr,
                null_equality: self.null_equality,
                state: SHJStreamState::PullRight,
                reservation,
                batch_transformer: NoopBatchTransformer::new(),
            }))
        }
    }

    /// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done,
    /// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan.
    /// Otherwise, it returns None.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
        else {
            return Ok(None);
        };

        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
            self.left().schema().fields().len(),
            &projection_as_columns,
        );

        if !join_allows_pushdown(
            &projection_as_columns,
            &self.schema(),
            far_right_left_col_ind,
            far_left_right_col_ind,
        ) {
            return Ok(None);
        }

        let Some(new_on) = update_join_on(
            &projection_as_columns[0..=far_right_left_col_ind as _],
            &projection_as_columns[far_left_right_col_ind as _..],
            self.on(),
            self.left().schema().fields().len(),
        ) else {
            return Ok(None);
        };

        let new_filter = if let Some(filter) = self.filter() {
            match update_join_filter(
                &projection_as_columns[0..=far_right_left_col_ind as _],
                &projection_as_columns[far_left_right_col_ind as _..],
                filter,
                self.left().schema().fields().len(),
            ) {
                Some(updated_filter) => Some(updated_filter),
                None => return Ok(None),
            }
        } else {
            None
        };

        let (new_left, new_right) = new_join_children(
            &projection_as_columns,
            far_right_left_col_ind,
            far_left_right_col_ind,
            self.left(),
            self.right(),
        )?;

        SymmetricHashJoinExec::try_new(
            Arc::new(new_left),
            Arc::new(new_right),
            new_on,
            new_filter,
            self.join_type(),
            self.null_equality(),
            self.right().output_ordering().cloned(),
            self.left().output_ordering().cloned(),
            self.partition_mode(),
        )
        .map(|e| Some(Arc::new(e) as _))
    }
}

/// A stream that issues [RecordBatch]es as they arrive from either side of the join.
struct SymmetricHashJoinStream<T> {
    /// Input streams
    left_stream: SendableRecordBatchStream,
    right_stream: SendableRecordBatchStream,
    /// Output schema of the join
    schema: Arc<Schema>,
    /// Join filter
    filter: Option<JoinFilter>,
    /// Type of the join
    join_type: JoinType,
    /// Left hash joiner
    left: OneSideHashJoiner,
    /// Right hash joiner
    right: OneSideHashJoiner,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Expression graph for range pruning
    graph: Option<ExprIntervalGraph>,
    /// Left globally sorted filter expr
    left_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Right globally sorted filter expr
    right_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Random state used for hashing initialization
    random_state: RandomState,
    /// Defines the null equality for the join.
    null_equality: NullEquality,
    /// Metrics
    metrics: StreamJoinMetrics,
    /// Memory reservation
    reservation: SharedMemoryReservation,
    /// State machine for input execution
    state: SHJStreamState,
    /// Transforms the output batch before returning.
    batch_transformer: T,
}

impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
    for SymmetricHashJoinStream<T>
{
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}

/// Determine the pruning length for `buffer`.
///
/// This function evaluates the build side filter expression, converts the
/// result into an array and determines the pruning length by performing a
/// binary search on the array.
///
/// # Arguments
///
/// * `buffer`: The record batch to be pruned.
/// * `build_side_filter_expr`: The filter expression on the build side used
///   to determine the pruning length.
///
/// # Returns
///
/// A [Result] object that contains the pruning length. The function will return
/// an error if
/// - there is an issue evaluating the build side filter expression;
/// - there is an issue converting the build side filter expression into an array.
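///
/// For intuition, a minimal sketch with hypothetical values (ascending sort
/// order, target chosen between two entries so either bisection side agrees):
///
/// ```text
/// build side filter column: [10, 20, 30, 40, 50]
/// calculated interval:      [35, +inf)   // lower bound from the expr graph
/// binary search for 35      => index 3
/// prune length              => 3         // rows 10, 20, 30 can be discarded
/// ```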
fn determine_prune_length(
    buffer: &RecordBatch,
    build_side_filter_expr: &SortedFilterExpr,
) -> Result<usize> {
    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
    let interval = build_side_filter_expr.interval();
    // Evaluate the build side filter expression and convert it into an array
    let batch_arr = origin_sorted_expr
        .expr
        .evaluate(buffer)?
        .into_array(buffer.num_rows())?;

    // Get the lower or upper interval based on the sort direction
    let target = if origin_sorted_expr.options.descending {
        interval.upper().clone()
    } else {
        interval.lower().clone()
    };

    // Perform binary search on the array to determine the length of the record batch to be pruned
    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
}

/// This method determines if the result of the join should be produced in the final step or not.
///
/// # Arguments
///
/// * `build_side` - Enum indicating the side of the join used as the build side.
/// * `join_type` - Enum indicating the type of join to be performed.
///
/// # Returns
///
/// A boolean indicating whether the result of the join should be produced in the final step or not.
/// The result will be true if the build side is JoinSide::Left and the join type is one of
/// JoinType::Left, JoinType::LeftAnti, JoinType::Full, JoinType::LeftSemi or JoinType::LeftMark.
/// If the build side is JoinSide::Right, the result will be true if the join type
/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, JoinType::RightSemi or JoinType::RightMark.
fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
    if build_side == JoinSide::Left {
        matches!(
            join_type,
            JoinType::Left
                | JoinType::LeftAnti
                | JoinType::Full
                | JoinType::LeftSemi
                | JoinType::LeftMark
        )
    } else {
        matches!(
            join_type,
            JoinType::Right
                | JoinType::RightAnti
                | JoinType::Full
                | JoinType::RightSemi
                | JoinType::RightMark
        )
    }
}

/// Calculate indices by join type.
///
/// This method returns a tuple of two arrays: build and probe indices.
/// The length of both arrays will be the same.
///
/// # Arguments
///
/// * `build_side`: Join side which defines the build side.
/// * `prune_length`: Length of the prune data.
/// * `visited_rows`: Hash set of visited rows of the build side.
/// * `deleted_offset`: Deleted offset of the build side.
/// * `join_type`: The type of join to be performed.
///
/// # Returns
///
/// A tuple of two arrays of primitive types representing the build and probe indices.
///
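/// For intuition, a minimal sketch with hypothetical values: a `Left` join with
/// the left side as build side, `prune_length = 4`, `deleted_offset = 0` and
/// `visited_rows = {1, 2}` yields the anti (unmatched) indices:
///
/// ```text
/// build_indices = [0, 3]        // rows 0 and 3 never matched
/// probe_indices = [None, None]  // null-padded to the same length
/// ```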
fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
    build_side: JoinSide,
    prune_length: usize,
    visited_rows: &HashSet<usize>,
    deleted_offset: usize,
    join_type: JoinType,
) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
where
    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
{
    // Store the result in a tuple
    let result = match (build_side, join_type) {
        // For a mark join we "mark" each build-side row with a dummy 0 in the probe-side index
        // if it ever matched. For example, if
        //
        // prune_length = 5
        // deleted_offset = 0
        // visited_rows = {1, 3}
        //
        // then we produce:
        //
        // build_indices = [0, 1, 2, 3, 4]
        // probe_indices = [None, Some(0), None, Some(0), None]
        //
        // Example: for each build row i in [0..5):
        //   - We always output its own index i in `build_indices`
        //   - We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None`
        (JoinSide::Left, JoinType::LeftMark) => {
            let build_indices = (0..prune_length)
                .map(L::Native::from_usize)
                .collect::<PrimitiveArray<L>>();
            let probe_indices = (0..prune_length)
                .map(|idx| {
                    // For mark join we output a dummy index 0 to indicate the row had a match
                    visited_rows
                        .contains(&(idx + deleted_offset))
                        .then_some(R::Native::from_usize(0).unwrap())
                })
                .collect();
            (build_indices, probe_indices)
        }
        (JoinSide::Right, JoinType::RightMark) => {
            let build_indices = (0..prune_length)
                .map(L::Native::from_usize)
                .collect::<PrimitiveArray<L>>();
            let probe_indices = (0..prune_length)
                .map(|idx| {
                    // For mark join we output a dummy index 0 to indicate the row had a match
                    visited_rows
                        .contains(&(idx + deleted_offset))
                        .then_some(R::Native::from_usize(0).unwrap())
                })
                .collect();
            (build_indices, probe_indices)
        }
        // For `Left`/`LeftAnti`, `Right`/`RightAnti` and `Full` joins, get the anti indices
        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
        | (_, JoinType::Full) => {
            let build_unmatched_indices =
                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
            let mut builder =
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
            builder.append_nulls(build_unmatched_indices.len());
            let probe_indices = builder.finish();
            (build_unmatched_indices, probe_indices)
        }
        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
            let build_unmatched_indices =
                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
            let mut builder =
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
            builder.append_nulls(build_unmatched_indices.len());
            let probe_indices = builder.finish();
            (build_unmatched_indices, probe_indices)
        }
        // The case of other join types is not considered
        _ => unreachable!(),
    };
    Ok(result)
}

/// This function produces unmatched record results based on the build side,
/// join type and other parameters.
///
/// The method uses the first `prune_length` rows from the build side input buffer
/// to produce results.
///
/// # Arguments
///
/// * `build_hash_joiner` - The build side hash joiner holding the buffered input.
/// * `output_schema` - The schema of the final output record batch.
/// * `prune_length` - The number of rows from the build side buffer to produce results for.
/// * `probe_schema` - The schema of the probe [RecordBatch].
/// * `join_type` - The type of join to be performed.
/// * `column_indices` - Indices of columns that are being joined.
///
/// # Returns
///
/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
pub(crate) fn build_side_determined_results(
    build_hash_joiner: &OneSideHashJoiner,
    output_schema: &SchemaRef,
    prune_length: usize,
    probe_schema: SchemaRef,
    join_type: JoinType,
    column_indices: &[ColumnIndex],
) -> Result<Option<RecordBatch>> {
    // Check if we need to produce a result in the final output:
    if prune_length > 0
        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
    {
        // Calculate the indices for build and probe sides based on join type and build side:
        let (build_indices, probe_indices) = calculate_indices_by_join_type(
            build_hash_joiner.build_side,
            prune_length,
            &build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            join_type,
        )?;

        // Create an empty probe record batch:
        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
        // Build the final result from the indices of build and probe sides:
        build_batch_from_indices(
            output_schema.as_ref(),
            &build_hash_joiner.input_buffer,
            &empty_probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
        )
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    } else {
        // If we don't need to produce a result, return None
        Ok(None)
    }
}

/// This method performs a join between the build side input buffer and the probe side batch.
///
/// # Arguments
///
/// * `build_hash_joiner` - Build side hash joiner
/// * `probe_hash_joiner` - Probe side hash joiner
/// * `schema` - A reference to the schema of the output record batch.
/// * `join_type` - The type of join to be performed.
/// * `filter` - An optional filter on the join condition.
/// * `probe_batch` - The second record batch to be joined.
/// * `column_indices` - An array of columns to be selected for the result of the join.
/// * `random_state` - The random state for the join.
/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
///
/// # Returns
///
/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`,
/// `RightAnti`, `LeftSemi`, `LeftMark` or `RightSemi`.
/// If the join type is one of the above five, the function will return [None].
#[allow(clippy::too_many_arguments)]
pub(crate) fn join_with_probe_batch(
    build_hash_joiner: &mut OneSideHashJoiner,
    probe_hash_joiner: &mut OneSideHashJoiner,
    schema: &SchemaRef,
    join_type: JoinType,
    filter: Option<&JoinFilter>,
    probe_batch: &RecordBatch,
    column_indices: &[ColumnIndex],
    random_state: &RandomState,
    null_equality: NullEquality,
) -> Result<Option<RecordBatch>> {
    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
        return Ok(None);
    }
    let (build_indices, probe_indices) = lookup_join_hashmap(
        &build_hash_joiner.hashmap,
        &build_hash_joiner.input_buffer,
        probe_batch,
        &build_hash_joiner.on,
        &probe_hash_joiner.on,
        random_state,
        null_equality,
        &mut build_hash_joiner.hashes_buffer,
        Some(build_hash_joiner.deleted_offset),
    )?;

    let (build_indices, probe_indices) = if let Some(filter) = filter {
        apply_join_filter_to_indices(
            &build_hash_joiner.input_buffer,
            probe_batch,
            build_indices,
            probe_indices,
            filter,
            build_hash_joiner.build_side,
            None,
        )?
    } else {
        (build_indices, probe_indices)
    };

    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
        record_visited_indices(
            &mut build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            &build_indices,
        );
    }
    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
        record_visited_indices(
            &mut probe_hash_joiner.visited_rows,
            probe_hash_joiner.offset,
            &probe_indices,
        );
    }
    if matches!(
        join_type,
        JoinType::LeftAnti
            | JoinType::RightAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::RightSemi
    ) {
        Ok(None)
    } else {
        build_batch_from_indices(
            schema,
            &build_hash_joiner.input_buffer,
            probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
        )
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    }
}

/// This method performs lookups against a JoinHashMap by the hash values of the join-key columns, and handles potential
/// hash collisions.
///
/// # Arguments
///
/// * `build_hashmap` - Hashmap collected from the build side data.
/// * `build_batch` - Build side record batch.
/// * `probe_batch` - Probe side record batch.
/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
/// * `random_state` - The random state for the join.
/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
/// * `deleted_offset` - Deleted offset for the build side data.
///
/// # Returns
///
/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from the build and probe sides,
/// matched by join key columns.
#[allow(clippy::too_many_arguments)]
fn lookup_join_hashmap(
    build_hashmap: &PruningJoinHashMap,
    build_batch: &RecordBatch,
    probe_batch: &RecordBatch,
    build_on: &[PhysicalExprRef],
    probe_on: &[PhysicalExprRef],
    random_state: &RandomState,
    null_equality: NullEquality,
    hashes_buffer: &mut Vec<u64>,
    deleted_offset: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
    let keys_values = probe_on
        .iter()
        .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))
        .collect::<Result<Vec<_>>>()?;
    let build_join_values = build_on
        .iter()
        .map(|c| c.evaluate(build_batch)?.into_array(build_batch.num_rows()))
        .collect::<Result<Vec<_>>>()?;

    hashes_buffer.clear();
    hashes_buffer.resize(probe_batch.num_rows(), 0);
    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;

    // As SymmetricHashJoin uses a LIFO JoinHashMap, the chained list algorithm
    // will return build indices for each probe row in reverse order, e.g.:
    // Build Indices: [5, 4, 3]
    // Probe Indices: [1, 1, 1]
    //
    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
    // Let's consider probe rows [0,1] as an example:
    //
    // When the probe iteration sequence is reversed, the following pairings can be derived:
    //
    // For probe row 1:
    //     (5, 1)
    //     (4, 1)
    //     (3, 1)
    //
    // For probe row 0:
    //     (5, 0)
    //     (4, 0)
    //     (3, 0)
    //
    // After reversing both sets of indices, we obtain reversed indices:
    //
    //     (3,0)
    //     (4,0)
    //     (5,0)
    //     (3,1)
    //     (4,1)
    //     (5,1)
    //
    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
    let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices(
        Box::new(hash_values.iter().enumerate().rev()),
        deleted_offset,
    );

    matched_probe.reverse();
    matched_build.reverse();

    let build_indices: UInt64Array = matched_build.into();
    let probe_indices: UInt32Array = matched_probe.into();

    let (build_indices, probe_indices) = equal_rows_arr(
        &build_indices,
        &probe_indices,
        &build_join_values,
        &keys_values,
        null_equality,
    )?;

    Ok((build_indices, probe_indices))
}

/// Holds the state (buffered input, hash table, matched-row bookkeeping) for
/// one side of the join, used when that side acts as the build side.
pub struct OneSideHashJoiner {
    /// Build side
    build_side: JoinSide,
    /// Input record batch buffer
    pub input_buffer: RecordBatch,
    /// Join-on key expressions from this side
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hashmap
    pub(crate) hashmap: PruningJoinHashMap,
    /// Reuse the hashes buffer
    pub(crate) hashes_buffer: Vec<u64>,
    /// Matched rows
    pub(crate) visited_rows: HashSet<usize>,
    /// Offset
    pub(crate) offset: usize,
    /// Deleted offset
    pub(crate) deleted_offset: usize,
}

impl OneSideHashJoiner {
    pub fn size(&self) -> usize {
        let mut size = 0;
        size += size_of_val(self);
        size += size_of_val(&self.build_side);
        size += self.input_buffer.get_array_memory_size();
        size += size_of_val(&self.on);
        size += self.hashmap.size();
        size += self.hashes_buffer.capacity() * size_of::<u64>();
        size += self.visited_rows.capacity() * size_of::<usize>();
        size += size_of_val(&self.offset);
        size += size_of_val(&self.deleted_offset);
        size
    }

    pub fn new(
        build_side: JoinSide,
        on: Vec<PhysicalExprRef>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            build_side,
            input_buffer: RecordBatch::new_empty(schema),
            on,
            hashmap: PruningJoinHashMap::with_capacity(0),
            hashes_buffer: vec![],
            visited_rows: HashSet::new(),
            offset: 0,
            deleted_offset: 0,
        }
    }

    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
    ///
    /// # Arguments
    ///
    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
    /// * `random_state` - The random state used to hash values
    ///
    /// # Returns
    ///
    /// Returns a [Result] encapsulating any intermediate errors.
    pub(crate) fn update_internal_state(
        &mut self,
        batch: &RecordBatch,
        random_state: &RandomState,
    ) -> Result<()> {
        // Merge the incoming batch with the existing input buffer:
        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
        // Resize the hashes buffer to the number of rows in the incoming batch:
        self.hashes_buffer.resize(batch.num_rows(), 0);
        // Update the hashmap with the join key values and hashes of the incoming batch:
        update_hash(
            &self.on,
            batch,
            &mut self.hashmap,
            self.offset,
            random_state,
            &mut self.hashes_buffer,
            self.deleted_offset,
            false,
        )?;
        Ok(())
    }

    /// Calculate prune length.
    ///
    /// # Arguments
    ///
    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression.
    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
    /// * `graph` - A mutable reference to the physical expression graph.
    ///
    /// # Returns
    ///
    /// A Result object that contains the pruning length.
    pub(crate) fn calculate_prune_length_with_probe_batch(
        &mut self,
        build_side_sorted_filter_expr: &mut SortedFilterExpr,
        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
        graph: &mut ExprIntervalGraph,
    ) -> Result<usize> {
        // Return early if the input buffer is empty:
        if self.input_buffer.num_rows() == 0 {
            return Ok(0);
        }
        // Process the build and probe side sorted filter expressions if both are present.
        // Collect the sorted filter expressions into a vector of (node_index, interval) tuples:
        let mut filter_intervals = vec![];
        for expr in [
            &build_side_sorted_filter_expr,
            &probe_side_sorted_filter_expr,
        ] {
            filter_intervals.push((expr.node_index(), expr.interval().clone()))
        }
        // Update the physical expression graph using the join filter intervals:
        graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?;
        // Extract the new join filter interval for the build side:
        let calculated_build_side_interval = filter_intervals.remove(0).1;
        // If the interval has not changed, return early without pruning:
        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
            return Ok(0);
        }
        // Update the build side interval and determine the pruning length:
        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);

        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
    }

    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
        // Prune the hash values:
        self.hashmap.prune_hash_values(
            prune_length,
            self.deleted_offset as u64,
            HASHMAP_SHRINK_SCALE_FACTOR,
        );
        // Remove pruned rows from the visited rows set:
        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
            self.visited_rows.remove(&row);
        }
        // Update the input buffer after pruning:
        self.input_buffer = self
            .input_buffer
            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
        // Increment the deleted offset:
        self.deleted_offset += prune_length;
        Ok(())
    }
}

/// `SymmetricHashJoinStream` manages incremental join operations between two
/// streams. Unlike traditional join approaches that need to scan one side of
/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
/// more dynamic join operations by working with streams as they emit data. This
/// approach allows for more efficient processing, particularly in scenarios
/// where waiting for complete data materialization is not feasible or optimal.
/// It provides a framework for handling various states of such a join
/// process, ensuring that join logic is efficiently executed as data becomes
/// available from either stream.
///
/// This implementation performs eager joins of data from two different asynchronous
/// streams, typically referred to as left and right streams. The implementation
/// provides a comprehensive set of methods to control and execute the join
/// process, leveraging the states defined in `SHJStreamState`. Methods are
/// primarily focused on asynchronously fetching data batches from each stream,
/// processing them, and managing transitions between various states of the join.
///
/// This implementation uses a state machine approach to navigate different
/// stages of the join operation, handling data from both streams and determining
/// when the join completes.
///
/// State Transitions:
/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
///       processing the batch.
///     - On error (`Some(Err(e))`), the error is returned, and the state remains
///       unchanged.
///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
///       to proceed with the join process.
/// - From `PullRight` to `PullLeft` or `RightExhausted`:
///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
///     - If a batch is available, state changes to `PullLeft` for processing.
///     - On error, the error is returned without changing the state.
///     - If the right stream is exhausted (`None`), state transitions to `RightExhausted`,
///       with a `Continue` result.
/// - Handling `RightExhausted` and `LeftExhausted`:
///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
///     when streams are exhausted:
///     - They attempt to continue processing with the other stream.
///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
/// - Transition to `BothExhausted { final_result: true }`:
///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
///     exhausted, indicating completion of processing and availability of final results.
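///
/// As a compact sketch of the transitions described above:
///
/// ```text
///   PullRight --batch--> PullLeft --batch--> PullRight --> ...
///       | stream ended            | stream ended
///       v                         v
///   RightExhausted            LeftExhausted
///       |     (both exhausted)    |
///       +------------+-----------+
///                    v
///   BothExhausted { final_result: false }
///                    |  emit any remaining unmatched build-side rows
///                    v
///   BothExhausted { final_result: true }   =>  poll returns None
/// ```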
1326impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
1327    /// Implements the main polling logic for the join stream.
1328    ///
1329    /// This method continuously checks the state of the join stream and
1330    /// acts accordingly by delegating the handling to appropriate sub-methods
1331    /// depending on the current state.
1332    ///
1333    /// # Arguments
1334    ///
1335    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
1336    ///
1337    /// # Returns
1338    ///
1339    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
1340    fn poll_next_impl(
1341        &mut self,
1342        cx: &mut Context<'_>,
1343    ) -> Poll<Option<Result<RecordBatch>>> {
1344        loop {
1345            match self.batch_transformer.next() {
1346                None => {
1347                    let result = match self.state() {
1348                        SHJStreamState::PullRight => {
1349                            ready!(self.fetch_next_from_right_stream(cx))
1350                        }
1351                        SHJStreamState::PullLeft => {
1352                            ready!(self.fetch_next_from_left_stream(cx))
1353                        }
1354                        SHJStreamState::RightExhausted => {
1355                            ready!(self.handle_right_stream_end(cx))
1356                        }
1357                        SHJStreamState::LeftExhausted => {
1358                            ready!(self.handle_left_stream_end(cx))
1359                        }
1360                        SHJStreamState::BothExhausted {
1361                            final_result: false,
1362                        } => self.prepare_for_final_results_after_exhaustion(),
1363                        SHJStreamState::BothExhausted { final_result: true } => {
1364                            return Poll::Ready(None);
1365                        }
1366                    };
1367
1368                    match result? {
1369                        StatefulStreamResult::Ready(None) => {
1370                            return Poll::Ready(None);
1371                        }
1372                        StatefulStreamResult::Ready(Some(batch)) => {
1373                            self.batch_transformer.set_batch(batch);
1374                        }
1375                        _ => {}
1376                    }
1377                }
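                // A transformed output batch is ready; record metrics and emit it: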
1378                Some((batch, _)) => {
1379                    self.metrics.output_batches.add(1);
1380                    return self
1381                        .metrics
1382                        .baseline_metrics
1383                        .record_poll(Poll::Ready(Some(Ok(batch))));
1384                }
1385            }
1386        }
1387    }

    /// Asynchronously pulls the next batch from the right stream.
    ///
    /// This method checks for the next value in the right stream. If a non-empty
    /// batch is found, the state is switched to `PullLeft` and the batch handling
    /// is delegated to `process_batch_from_right`; empty batches are skipped with
    /// a `Continue` result. If the stream ends, the state is set to
    /// `RightExhausted`.
1393    ///
1394    /// # Returns
1395    ///
    /// * `Poll<Result<StatefulStreamResult<Option<RecordBatch>>>>` - The state
    ///   result after pulling the batch.
1397    fn fetch_next_from_right_stream(
1398        &mut self,
1399        cx: &mut Context<'_>,
1400    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1401        match ready!(self.right_stream().poll_next_unpin(cx)) {
1402            Some(Ok(batch)) => {
1403                if batch.num_rows() == 0 {
1404                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1405                }
1406                self.set_state(SHJStreamState::PullLeft);
1407                Poll::Ready(self.process_batch_from_right(batch))
1408            }
1409            Some(Err(e)) => Poll::Ready(Err(e)),
1410            None => {
1411                self.set_state(SHJStreamState::RightExhausted);
1412                Poll::Ready(Ok(StatefulStreamResult::Continue))
1413            }
1414        }
1415    }
1416
1417    /// Asynchronously pulls the next batch from the left stream.
1418    ///
    /// This method checks for the next value in the left stream. If a non-empty
    /// batch is found, the state is switched to `PullRight` and the batch handling
    /// is delegated to `process_batch_from_left`; empty batches are skipped with
    /// a `Continue` result. If the stream ends, the state is set to
    /// `LeftExhausted`.
1422    ///
1423    /// # Returns
1424    ///
    /// * `Poll<Result<StatefulStreamResult<Option<RecordBatch>>>>` - The state
    ///   result after pulling the batch.
1426    fn fetch_next_from_left_stream(
1427        &mut self,
1428        cx: &mut Context<'_>,
1429    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1430        match ready!(self.left_stream().poll_next_unpin(cx)) {
1431            Some(Ok(batch)) => {
1432                if batch.num_rows() == 0 {
1433                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1434                }
1435                self.set_state(SHJStreamState::PullRight);
1436                Poll::Ready(self.process_batch_from_left(batch))
1437            }
1438            Some(Err(e)) => Poll::Ready(Err(e)),
1439            None => {
1440                self.set_state(SHJStreamState::LeftExhausted);
1441                Poll::Ready(Ok(StatefulStreamResult::Continue))
1442            }
1443        }
1444    }
1445
1446    /// Asynchronously handles the scenario when the right stream is exhausted.
1447    ///
    /// When the right stream is exhausted, this method attempts to pull from the
    /// left stream. If a batch is found there, its handling is delegated to
    /// `process_batch_after_right_end`. If both streams are exhausted, the state
    /// is set to indicate both streams are exhausted without final results yet.
1452    ///
1453    /// # Returns
1454    ///
    /// * `Poll<Result<StatefulStreamResult<Option<RecordBatch>>>>` - The state
    ///   result after checking the exhaustion state.
1456    fn handle_right_stream_end(
1457        &mut self,
1458        cx: &mut Context<'_>,
1459    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1460        match ready!(self.left_stream().poll_next_unpin(cx)) {
1461            Some(Ok(batch)) => {
1462                if batch.num_rows() == 0 {
1463                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1464                }
1465                Poll::Ready(self.process_batch_after_right_end(batch))
1466            }
1467            Some(Err(e)) => Poll::Ready(Err(e)),
1468            None => {
1469                self.set_state(SHJStreamState::BothExhausted {
1470                    final_result: false,
1471                });
1472                Poll::Ready(Ok(StatefulStreamResult::Continue))
1473            }
1474        }
1475    }
1476
1477    /// Asynchronously handles the scenario when the left stream is exhausted.
1478    ///
    /// When the left stream is exhausted, this method tries to pull from the
    /// right stream and delegates the batch handling to
    /// `process_batch_after_left_end`. If both streams are exhausted, the state
    /// is updated to indicate so.
1483    ///
1484    /// # Returns
1485    ///
    /// * `Poll<Result<StatefulStreamResult<Option<RecordBatch>>>>` - The state
    ///   result after checking the exhaustion state.
1487    fn handle_left_stream_end(
1488        &mut self,
1489        cx: &mut Context<'_>,
1490    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1491        match ready!(self.right_stream().poll_next_unpin(cx)) {
1492            Some(Ok(batch)) => {
1493                if batch.num_rows() == 0 {
1494                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1495                }
1496                Poll::Ready(self.process_batch_after_left_end(batch))
1497            }
1498            Some(Err(e)) => Poll::Ready(Err(e)),
1499            None => {
1500                self.set_state(SHJStreamState::BothExhausted {
1501                    final_result: false,
1502                });
1503                Poll::Ready(Ok(StatefulStreamResult::Continue))
1504            }
1505        }
1506    }
1507
1508    /// Handles the state when both streams are exhausted and final results are yet to be produced.
1509    ///
    /// This method switches the state to indicate that both streams are
    /// exhausted with final results, and then builds those results via
    /// `process_batches_before_finalization`.
1513    ///
1514    /// # Returns
1515    ///
1516    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
1517    fn prepare_for_final_results_after_exhaustion(
1518        &mut self,
1519    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1520        self.set_state(SHJStreamState::BothExhausted { final_result: true });
1521        self.process_batches_before_finalization()
1522    }
1523
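    /// Joins a batch pulled from the right stream against the left side,
    /// returning `Ready` if output was produced and `Continue` otherwise.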
1524    fn process_batch_from_right(
1525        &mut self,
1526        batch: RecordBatch,
1527    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1528        self.perform_join_for_given_side(batch, JoinSide::Right)
1529            .map(|maybe_batch| {
1530                if maybe_batch.is_some() {
1531                    StatefulStreamResult::Ready(maybe_batch)
1532                } else {
1533                    StatefulStreamResult::Continue
1534                }
1535            })
1536    }
1537
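    /// Joins a batch pulled from the left stream against the right side,
    /// returning `Ready` if output was produced and `Continue` otherwise.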
1538    fn process_batch_from_left(
1539        &mut self,
1540        batch: RecordBatch,
1541    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1542        self.perform_join_for_given_side(batch, JoinSide::Left)
1543            .map(|maybe_batch| {
1544                if maybe_batch.is_some() {
1545                    StatefulStreamResult::Ready(maybe_batch)
1546                } else {
1547                    StatefulStreamResult::Continue
1548                }
1549            })
1550    }
1551
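    /// Processes a right-stream batch after the left stream has ended; with no
    /// more left input coming, this reduces to a regular probe from the right.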
1552    fn process_batch_after_left_end(
1553        &mut self,
1554        right_batch: RecordBatch,
1555    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1556        self.process_batch_from_right(right_batch)
1557    }
1558
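    /// Processes a left-stream batch after the right stream has ended; with no
    /// more right input coming, this reduces to a regular probe from the left.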
1559    fn process_batch_after_right_end(
1560        &mut self,
1561        left_batch: RecordBatch,
1562    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1563        self.process_batch_from_left(left_batch)
1564    }
1565
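    /// Builds any remaining results (e.g. unmatched build-side rows for outer
    /// joins) from both sides once both streams are exhausted.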
1566    fn process_batches_before_finalization(
1567        &mut self,
1568    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1569        // Get the left side results:
1570        let left_result = build_side_determined_results(
1571            &self.left,
1572            &self.schema,
1573            self.left.input_buffer.num_rows(),
1574            self.right.input_buffer.schema(),
1575            self.join_type,
1576            &self.column_indices,
1577        )?;
1578        // Get the right side results:
1579        let right_result = build_side_determined_results(
1580            &self.right,
1581            &self.schema,
1582            self.right.input_buffer.num_rows(),
1583            self.left.input_buffer.schema(),
1584            self.join_type,
1585            &self.column_indices,
1586        )?;
1587
1588        // Combine the left and right results:
1589        let result = combine_two_batches(&self.schema, left_result, right_result)?;
1590
1591        // Return the result:
1592        if result.is_some() {
1593            return Ok(StatefulStreamResult::Ready(result));
1594        }
1595        Ok(StatefulStreamResult::Continue)
1596    }
1597
1598    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
1599        &mut self.right_stream
1600    }
1601
1602    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
1603        &mut self.left_stream
1604    }
1605
1606    fn set_state(&mut self, state: SHJStreamState) {
1607        self.state = state;
1608    }
1609
1610    fn state(&mut self) -> SHJStreamState {
1611        self.state.clone()
1612    }
1613
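    /// Estimates the current memory footprint of this stream's state, used for
    /// memory-pool accounting in `perform_join_for_given_side`.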
1614    fn size(&self) -> usize {
1615        let mut size = 0;
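        // `size_of_val` only measures the shallow (stack) size of each field;
        // the buffered data is accounted for via `self.left.size()` and
        // `self.right.size()` below.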
1616        size += size_of_val(&self.schema);
1617        size += size_of_val(&self.filter);
1618        size += size_of_val(&self.join_type);
1619        size += self.left.size();
1620        size += self.right.size();
1621        size += size_of_val(&self.column_indices);
1622        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
1623        size += size_of_val(&self.left_sorted_filter_expr);
1624        size += size_of_val(&self.right_sorted_filter_expr);
1625        size += size_of_val(&self.random_state);
1626        size += size_of_val(&self.null_equality);
1627        size += size_of_val(&self.metrics);
1628        size
1629    }
1630
1631    /// Performs a join operation for the specified `probe_side` (either left or right).
1632    /// This function:
1633    /// 1. Determines which side is the probe and which is the build side.
1634    /// 2. Updates metrics based on the batch that was polled.
1635    /// 3. Executes the join with the given `probe_batch`.
1636    /// 4. Optionally computes anti-join results if all conditions are met.
1637    /// 5. Combines the results and returns a combined batch or `None` if no batch was produced.
1638    fn perform_join_for_given_side(
1639        &mut self,
1640        probe_batch: RecordBatch,
1641        probe_side: JoinSide,
1642    ) -> Result<Option<RecordBatch>> {
1643        let (
1644            probe_hash_joiner,
1645            build_hash_joiner,
1646            probe_side_sorted_filter_expr,
1647            build_side_sorted_filter_expr,
1648            probe_side_metrics,
        ) = if probe_side == JoinSide::Left {
1650            (
1651                &mut self.left,
1652                &mut self.right,
1653                &mut self.left_sorted_filter_expr,
1654                &mut self.right_sorted_filter_expr,
1655                &mut self.metrics.left,
1656            )
1657        } else {
1658            (
1659                &mut self.right,
1660                &mut self.left,
1661                &mut self.right_sorted_filter_expr,
1662                &mut self.left_sorted_filter_expr,
1663                &mut self.metrics.right,
1664            )
1665        };
1666        // Update the metrics for the stream that was polled:
1667        probe_side_metrics.input_batches.add(1);
1668        probe_side_metrics.input_rows.add(probe_batch.num_rows());
        // Update the probe side's own hash joiner with this batch; that side
        // acts as the build side for future batches from the other stream:
1670        probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?;
1671        // Join the two sides:
1672        let equal_result = join_with_probe_batch(
1673            build_hash_joiner,
1674            probe_hash_joiner,
1675            &self.schema,
1676            self.join_type,
1677            self.filter.as_ref(),
1678            &probe_batch,
1679            &self.column_indices,
1680            &self.random_state,
1681            self.null_equality,
1682        )?;
1683        // Increment the offset for the probe hash joiner:
1684        probe_hash_joiner.offset += probe_batch.num_rows();
1685
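        // If sorted filter expressions and an interval graph are available, use
        // interval arithmetic to find build-side rows that can no longer match,
        // emit any results they still owe (e.g. for outer joins), and prune them: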
1686        let anti_result = if let (
1687            Some(build_side_sorted_filter_expr),
1688            Some(probe_side_sorted_filter_expr),
1689            Some(graph),
1690        ) = (
1691            build_side_sorted_filter_expr.as_mut(),
1692            probe_side_sorted_filter_expr.as_mut(),
1693            self.graph.as_mut(),
1694        ) {
1695            // Calculate filter intervals:
1696            calculate_filter_expr_intervals(
1697                &build_hash_joiner.input_buffer,
1698                build_side_sorted_filter_expr,
1699                &probe_batch,
1700                probe_side_sorted_filter_expr,
1701            )?;
1702            let prune_length = build_hash_joiner
1703                .calculate_prune_length_with_probe_batch(
1704                    build_side_sorted_filter_expr,
1705                    probe_side_sorted_filter_expr,
1706                    graph,
1707                )?;
1708            let result = build_side_determined_results(
1709                build_hash_joiner,
1710                &self.schema,
1711                prune_length,
1712                probe_batch.schema(),
1713                self.join_type,
1714                &self.column_indices,
1715            )?;
1716            build_hash_joiner.prune_internal_state(prune_length)?;
1717            result
1718        } else {
1719            None
1720        };
1721
1722        // Combine results:
1723        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
1724        let capacity = self.size();
1725        self.metrics.stream_memory_usage.set(capacity);
1726        self.reservation.lock().try_resize(capacity)?;
1727        Ok(result)
1728    }
1729}
1730
/// Represents the various states of a symmetric hash join stream operation.
1732///
1733/// This enum is used to track the current state of streaming during a join
1734/// operation. It provides indicators as to which side of the join needs to be
1735/// pulled next or if one (or both) sides have been exhausted. This allows
1736/// for efficient management of resources and optimal performance during the
1737/// join process.
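///
/// For two finite streams, a typical progression is to alternate between
/// `PullRight` and `PullLeft` while both sides produce data, then move to
/// `RightExhausted` or `LeftExhausted`, and finally to
/// `BothExhausted { final_result: false }` followed by
/// `BothExhausted { final_result: true }`.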
1738#[derive(Clone, Debug)]
1739pub enum SHJStreamState {
1740    /// Indicates that the next step should pull from the right side of the join.
1741    PullRight,
1742
1743    /// Indicates that the next step should pull from the left side of the join.
1744    PullLeft,
1745
1746    /// State representing that the right side of the join has been fully processed.
1747    RightExhausted,
1748
1749    /// State representing that the left side of the join has been fully processed.
1750    LeftExhausted,
1751
1752    /// Represents a state where both sides of the join are exhausted.
1753    ///
1754    /// The `final_result` field indicates whether the join operation has
1755    /// produced a final result or not.
1756    BothExhausted { final_result: bool },
1757}
1758
1759#[cfg(test)]
1760mod tests {
1761    use std::collections::HashMap;
1762    use std::sync::{LazyLock, Mutex};
1763
1764    use super::*;
1765    use crate::joins::test_utils::{
1766        build_sides_record_batches, compare_batches, complicated_filter,
1767        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1768        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1769        partitioned_sym_join_with_filter, split_record_batches,
1770    };
1771
1772    use arrow::compute::SortOptions;
1773    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1774    use datafusion_common::ScalarValue;
1775    use datafusion_execution::config::SessionConfig;
1776    use datafusion_expr::Operator;
1777    use datafusion_physical_expr::expressions::{binary, col, lit, Column};
1778    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
1779
1780    use rstest::*;
1781
1782    const TABLE_SIZE: i32 = 30;
1783
1784    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
1785    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)
1786
1787    // Cache for storing tables
1788    static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
1789        LazyLock::new(|| Mutex::new(HashMap::new()));
1790
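    /// Returns the (left, right) partitions for the given cardinality and batch
    /// size, building them on first use and serving cached copies thereafter.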
1791    fn get_or_create_table(
1792        cardinality: (i32, i32),
1793        batch_size: usize,
1794    ) -> Result<TableValue> {
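        // First check the cache, scoping the lock so it is released before any
        // table building happens: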
1795        {
1796            let cache = TABLE_CACHE.lock().unwrap();
1797            if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1798                return Ok(table.clone());
1799            }
1800        }
1801
1802        // If not, create the table
1803        let (left_batch, right_batch) =
1804            build_sides_record_batches(TABLE_SIZE, cardinality)?;
1805
1806        let (left_partition, right_partition) = (
1807            split_record_batches(&left_batch, batch_size)?,
1808            split_record_batches(&right_batch, batch_size)?,
1809        );
1810
        // Lock the cache again and store the table:
        let mut cache = TABLE_CACHE.lock().unwrap();
        cache.insert(
1816            (cardinality.0, cardinality.1, batch_size),
1817            (left_partition.clone(), right_partition.clone()),
1818        );
1819
1820        Ok((left_partition, right_partition))
1821    }
1822
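    /// Runs the same join as a partitioned symmetric hash join and as a
    /// partitioned hash join, then asserts that both produce identical batches.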
1823    pub async fn experiment(
1824        left: Arc<dyn ExecutionPlan>,
1825        right: Arc<dyn ExecutionPlan>,
1826        filter: Option<JoinFilter>,
1827        join_type: JoinType,
1828        on: JoinOn,
1829        task_ctx: Arc<TaskContext>,
1830    ) -> Result<()> {
1831        let first_batches = partitioned_sym_join_with_filter(
1832            Arc::clone(&left),
1833            Arc::clone(&right),
1834            on.clone(),
1835            filter.clone(),
1836            &join_type,
1837            NullEquality::NullEqualsNothing,
1838            Arc::clone(&task_ctx),
1839        )
1840        .await?;
1841        let second_batches = partitioned_hash_join_with_filter(
1842            left,
1843            right,
1844            on,
1845            filter,
1846            &join_type,
1847            NullEquality::NullEqualsNothing,
1848            task_ctx,
1849        )
1850        .await?;
1851        compare_batches(&first_batches, &second_batches);
1852        Ok(())
1853    }
1854
1855    #[rstest]
1856    #[tokio::test(flavor = "multi_thread")]
1857    async fn complex_join_all_one_ascending_numeric(
1858        #[values(
1859            JoinType::Inner,
1860            JoinType::Left,
1861            JoinType::Right,
1862            JoinType::RightSemi,
1863            JoinType::LeftSemi,
1864            JoinType::LeftAnti,
1865            JoinType::LeftMark,
1866            JoinType::RightAnti,
1867            JoinType::Full
1868        )]
1869        join_type: JoinType,
1870        #[values(
1871        (4, 5),
1872        (12, 17),
1873        )]
1874        cardinality: (i32, i32),
1875    ) -> Result<()> {
1876        // a + b > c + 10 AND a + b < c + 100
1877        let task_ctx = Arc::new(TaskContext::default());
1878
1879        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
1880
1881        let left_schema = &left_partition[0].schema();
1882        let right_schema = &right_partition[0].schema();
1883
1884        let left_sorted = [PhysicalSortExpr {
1885            expr: binary(
1886                col("la1", left_schema)?,
1887                Operator::Plus,
1888                col("la2", left_schema)?,
1889                left_schema,
1890            )?,
1891            options: SortOptions::default(),
1892        }]
1893        .into();
1894        let right_sorted = [PhysicalSortExpr {
1895            expr: col("ra1", right_schema)?,
1896            options: SortOptions::default(),
1897        }]
1898        .into();
1899        let (left, right) = create_memory_table(
1900            left_partition,
1901            right_partition,
1902            vec![left_sorted],
1903            vec![right_sorted],
1904        )?;
1905
1906        let on = vec![(
1907            binary(
1908                col("lc1", left_schema)?,
1909                Operator::Plus,
1910                lit(ScalarValue::Int32(Some(1))),
1911                left_schema,
1912            )?,
1913            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1914        )];
1915
1916        let intermediate_schema = Schema::new(vec![
1917            Field::new("0", DataType::Int32, true),
1918            Field::new("1", DataType::Int32, true),
1919            Field::new("2", DataType::Int32, true),
1920        ]);
1921        let filter_expr = complicated_filter(&intermediate_schema)?;
1922        let column_indices = vec![
1923            ColumnIndex {
1924                index: left_schema.index_of("la1")?,
1925                side: JoinSide::Left,
1926            },
1927            ColumnIndex {
1928                index: left_schema.index_of("la2")?,
1929                side: JoinSide::Left,
1930            },
1931            ColumnIndex {
1932                index: right_schema.index_of("ra1")?,
1933                side: JoinSide::Right,
1934            },
1935        ];
1936        let filter =
1937            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1938
1939        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1940        Ok(())
1941    }
1942
1943    #[rstest]
1944    #[tokio::test(flavor = "multi_thread")]
1945    async fn join_all_one_ascending_numeric(
1946        #[values(
1947            JoinType::Inner,
1948            JoinType::Left,
1949            JoinType::Right,
1950            JoinType::RightSemi,
1951            JoinType::LeftSemi,
1952            JoinType::LeftAnti,
1953            JoinType::LeftMark,
1954            JoinType::RightAnti,
1955            JoinType::Full
1956        )]
1957        join_type: JoinType,
1958        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1959    ) -> Result<()> {
1960        let task_ctx = Arc::new(TaskContext::default());
1961        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1962
1963        let left_schema = &left_partition[0].schema();
1964        let right_schema = &right_partition[0].schema();
1965
1966        let left_sorted = [PhysicalSortExpr {
1967            expr: col("la1", left_schema)?,
1968            options: SortOptions::default(),
1969        }]
1970        .into();
1971        let right_sorted = [PhysicalSortExpr {
1972            expr: col("ra1", right_schema)?,
1973            options: SortOptions::default(),
1974        }]
1975        .into();
1976        let (left, right) = create_memory_table(
1977            left_partition,
1978            right_partition,
1979            vec![left_sorted],
1980            vec![right_sorted],
1981        )?;
1982
1983        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
1984
1985        let intermediate_schema = Schema::new(vec![
1986            Field::new("left", DataType::Int32, true),
1987            Field::new("right", DataType::Int32, true),
1988        ]);
1989        let filter_expr = join_expr_tests_fixture_i32(
1990            case_expr,
1991            col("left", &intermediate_schema)?,
1992            col("right", &intermediate_schema)?,
1993        );
1994        let column_indices = vec![
1995            ColumnIndex {
1996                index: 0,
1997                side: JoinSide::Left,
1998            },
1999            ColumnIndex {
2000                index: 0,
2001                side: JoinSide::Right,
2002            },
2003        ];
2004        let filter =
2005            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2006
2007        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2008        Ok(())
2009    }
2010
2011    #[rstest]
2012    #[tokio::test(flavor = "multi_thread")]
2013    async fn join_without_sort_information(
2014        #[values(
2015            JoinType::Inner,
2016            JoinType::Left,
2017            JoinType::Right,
2018            JoinType::RightSemi,
2019            JoinType::LeftSemi,
2020            JoinType::LeftAnti,
2021            JoinType::LeftMark,
2022            JoinType::RightAnti,
2023            JoinType::Full
2024        )]
2025        join_type: JoinType,
2026        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2027    ) -> Result<()> {
2028        let task_ctx = Arc::new(TaskContext::default());
2029        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
2030
2031        let left_schema = &left_partition[0].schema();
2032        let right_schema = &right_partition[0].schema();
2033        let (left, right) =
2034            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2035
2036        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2037
2038        let intermediate_schema = Schema::new(vec![
2039            Field::new("left", DataType::Int32, true),
2040            Field::new("right", DataType::Int32, true),
2041        ]);
2042        let filter_expr = join_expr_tests_fixture_i32(
2043            case_expr,
2044            col("left", &intermediate_schema)?,
2045            col("right", &intermediate_schema)?,
2046        );
2047        let column_indices = vec![
2048            ColumnIndex {
2049                index: 5,
2050                side: JoinSide::Left,
2051            },
2052            ColumnIndex {
2053                index: 5,
2054                side: JoinSide::Right,
2055            },
2056        ];
2057        let filter =
2058            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2059
2060        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2061        Ok(())
2062    }
2063
2064    #[rstest]
2065    #[tokio::test(flavor = "multi_thread")]
2066    async fn join_without_filter(
2067        #[values(
2068            JoinType::Inner,
2069            JoinType::Left,
2070            JoinType::Right,
2071            JoinType::RightSemi,
2072            JoinType::LeftSemi,
2073            JoinType::LeftAnti,
2074            JoinType::LeftMark,
2075            JoinType::RightAnti,
2076            JoinType::Full
2077        )]
2078        join_type: JoinType,
2079    ) -> Result<()> {
2080        let task_ctx = Arc::new(TaskContext::default());
2081        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2082        let left_schema = &left_partition[0].schema();
2083        let right_schema = &right_partition[0].schema();
2084        let (left, right) =
2085            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2086
2087        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2088        experiment(left, right, None, join_type, on, task_ctx).await?;
2089        Ok(())
2090    }
2091
2092    #[rstest]
2093    #[tokio::test(flavor = "multi_thread")]
2094    async fn join_all_one_descending_numeric_particular(
2095        #[values(
2096            JoinType::Inner,
2097            JoinType::Left,
2098            JoinType::Right,
2099            JoinType::RightSemi,
2100            JoinType::LeftSemi,
2101            JoinType::LeftAnti,
2102            JoinType::LeftMark,
2103            JoinType::RightAnti,
2104            JoinType::Full
2105        )]
2106        join_type: JoinType,
2107        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2108    ) -> Result<()> {
2109        let task_ctx = Arc::new(TaskContext::default());
2110        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2111        let left_schema = &left_partition[0].schema();
2112        let right_schema = &right_partition[0].schema();
2113        let left_sorted = [PhysicalSortExpr {
2114            expr: col("la1_des", left_schema)?,
2115            options: SortOptions {
2116                descending: true,
2117                nulls_first: true,
2118            },
2119        }]
2120        .into();
2121        let right_sorted = [PhysicalSortExpr {
2122            expr: col("ra1_des", right_schema)?,
2123            options: SortOptions {
2124                descending: true,
2125                nulls_first: true,
2126            },
2127        }]
2128        .into();
2129        let (left, right) = create_memory_table(
2130            left_partition,
2131            right_partition,
2132            vec![left_sorted],
2133            vec![right_sorted],
2134        )?;
2135
2136        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2137
2138        let intermediate_schema = Schema::new(vec![
2139            Field::new("left", DataType::Int32, true),
2140            Field::new("right", DataType::Int32, true),
2141        ]);
2142        let filter_expr = join_expr_tests_fixture_i32(
2143            case_expr,
2144            col("left", &intermediate_schema)?,
2145            col("right", &intermediate_schema)?,
2146        );
2147        let column_indices = vec![
2148            ColumnIndex {
2149                index: 5,
2150                side: JoinSide::Left,
2151            },
2152            ColumnIndex {
2153                index: 5,
2154                side: JoinSide::Right,
2155            },
2156        ];
2157        let filter =
2158            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2159
2160        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2161        Ok(())
2162    }
2163
2164    #[tokio::test(flavor = "multi_thread")]
2165    async fn build_null_columns_first() -> Result<()> {
2166        let join_type = JoinType::Full;
2167        let case_expr = 1;
2168        let session_config = SessionConfig::new().with_repartition_joins(false);
2169        let task_ctx = TaskContext::default().with_session_config(session_config);
2170        let task_ctx = Arc::new(task_ctx);
2171        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2172        let left_schema = &left_partition[0].schema();
2173        let right_schema = &right_partition[0].schema();
2174        let left_sorted = [PhysicalSortExpr {
2175            expr: col("l_asc_null_first", left_schema)?,
2176            options: SortOptions {
2177                descending: false,
2178                nulls_first: true,
2179            },
2180        }]
2181        .into();
2182        let right_sorted = [PhysicalSortExpr {
2183            expr: col("r_asc_null_first", right_schema)?,
2184            options: SortOptions {
2185                descending: false,
2186                nulls_first: true,
2187            },
2188        }]
2189        .into();
2190        let (left, right) = create_memory_table(
2191            left_partition,
2192            right_partition,
2193            vec![left_sorted],
2194            vec![right_sorted],
2195        )?;
2196
2197        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2198
2199        let intermediate_schema = Schema::new(vec![
2200            Field::new("left", DataType::Int32, true),
2201            Field::new("right", DataType::Int32, true),
2202        ]);
2203        let filter_expr = join_expr_tests_fixture_i32(
2204            case_expr,
2205            col("left", &intermediate_schema)?,
2206            col("right", &intermediate_schema)?,
2207        );
2208        let column_indices = vec![
2209            ColumnIndex {
2210                index: 6,
2211                side: JoinSide::Left,
2212            },
2213            ColumnIndex {
2214                index: 6,
2215                side: JoinSide::Right,
2216            },
2217        ];
2218        let filter =
2219            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2220        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2221        Ok(())
2222    }
2223
2224    #[tokio::test(flavor = "multi_thread")]
2225    async fn build_null_columns_last() -> Result<()> {
2226        let join_type = JoinType::Full;
2227        let case_expr = 1;
2228        let session_config = SessionConfig::new().with_repartition_joins(false);
2229        let task_ctx = TaskContext::default().with_session_config(session_config);
2230        let task_ctx = Arc::new(task_ctx);
2231        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2232
2233        let left_schema = &left_partition[0].schema();
2234        let right_schema = &right_partition[0].schema();
2235        let left_sorted = [PhysicalSortExpr {
2236            expr: col("l_asc_null_last", left_schema)?,
2237            options: SortOptions {
2238                descending: false,
2239                nulls_first: false,
2240            },
2241        }]
2242        .into();
2243        let right_sorted = [PhysicalSortExpr {
2244            expr: col("r_asc_null_last", right_schema)?,
2245            options: SortOptions {
2246                descending: false,
2247                nulls_first: false,
2248            },
2249        }]
2250        .into();
2251        let (left, right) = create_memory_table(
2252            left_partition,
2253            right_partition,
2254            vec![left_sorted],
2255            vec![right_sorted],
2256        )?;
2257
2258        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2259
2260        let intermediate_schema = Schema::new(vec![
2261            Field::new("left", DataType::Int32, true),
2262            Field::new("right", DataType::Int32, true),
2263        ]);
2264        let filter_expr = join_expr_tests_fixture_i32(
2265            case_expr,
2266            col("left", &intermediate_schema)?,
2267            col("right", &intermediate_schema)?,
2268        );
2269        let column_indices = vec![
2270            ColumnIndex {
2271                index: 7,
2272                side: JoinSide::Left,
2273            },
2274            ColumnIndex {
2275                index: 7,
2276                side: JoinSide::Right,
2277            },
2278        ];
2279        let filter =
2280            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2281
2282        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2283        Ok(())
2284    }
2285
2286    #[tokio::test(flavor = "multi_thread")]
2287    async fn build_null_columns_first_descending() -> Result<()> {
2288        let join_type = JoinType::Full;
2289        let cardinality = (10, 11);
2290        let case_expr = 1;
2291        let session_config = SessionConfig::new().with_repartition_joins(false);
2292        let task_ctx = TaskContext::default().with_session_config(session_config);
2293        let task_ctx = Arc::new(task_ctx);
2294        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2295
2296        let left_schema = &left_partition[0].schema();
2297        let right_schema = &right_partition[0].schema();
2298        let left_sorted = [PhysicalSortExpr {
2299            expr: col("l_desc_null_first", left_schema)?,
2300            options: SortOptions {
2301                descending: true,
2302                nulls_first: true,
2303            },
2304        }]
2305        .into();
2306        let right_sorted = [PhysicalSortExpr {
2307            expr: col("r_desc_null_first", right_schema)?,
2308            options: SortOptions {
2309                descending: true,
2310                nulls_first: true,
2311            },
2312        }]
2313        .into();
2314        let (left, right) = create_memory_table(
2315            left_partition,
2316            right_partition,
2317            vec![left_sorted],
2318            vec![right_sorted],
2319        )?;
2320
2321        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2322
2323        let intermediate_schema = Schema::new(vec![
2324            Field::new("left", DataType::Int32, true),
2325            Field::new("right", DataType::Int32, true),
2326        ]);
2327        let filter_expr = join_expr_tests_fixture_i32(
2328            case_expr,
2329            col("left", &intermediate_schema)?,
2330            col("right", &intermediate_schema)?,
2331        );
2332        let column_indices = vec![
2333            ColumnIndex {
2334                index: 8,
2335                side: JoinSide::Left,
2336            },
2337            ColumnIndex {
2338                index: 8,
2339                side: JoinSide::Right,
2340            },
2341        ];
2342        let filter =
2343            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2344
2345        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2346        Ok(())
2347    }
2348
2349    #[tokio::test(flavor = "multi_thread")]
2350    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2351        let cardinality = (3, 4);
2352        let join_type = JoinType::Full;
2353
2354        // a + b > c + 10 AND a + b < c + 100
2355        let session_config = SessionConfig::new().with_repartition_joins(false);
2356        let task_ctx = TaskContext::default().with_session_config(session_config);
2357        let task_ctx = Arc::new(task_ctx);
2358        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2359
2360        let left_schema = &left_partition[0].schema();
2361        let right_schema = &right_partition[0].schema();
2362        let left_sorted = [PhysicalSortExpr {
2363            expr: col("la1", left_schema)?,
2364            options: SortOptions::default(),
2365        }]
2366        .into();
2367        let right_sorted = [PhysicalSortExpr {
2368            expr: col("ra1", right_schema)?,
2369            options: SortOptions::default(),
2370        }]
2371        .into();
2372        let (left, right) = create_memory_table(
2373            left_partition,
2374            right_partition,
2375            vec![left_sorted],
2376            vec![right_sorted],
2377        )?;
2378
2379        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2380
2381        let intermediate_schema = Schema::new(vec![
2382            Field::new("0", DataType::Int32, true),
2383            Field::new("1", DataType::Int32, true),
2384            Field::new("2", DataType::Int32, true),
2385        ]);
2386        let filter_expr = complicated_filter(&intermediate_schema)?;
2387        let column_indices = vec![
2388            ColumnIndex {
2389                index: 0,
2390                side: JoinSide::Left,
2391            },
2392            ColumnIndex {
2393                index: 4,
2394                side: JoinSide::Left,
2395            },
2396            ColumnIndex {
2397                index: 0,
2398                side: JoinSide::Right,
2399            },
2400        ];
2401        let filter =
2402            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2403
2404        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2405        Ok(())
2406    }
2407
2408    #[tokio::test(flavor = "multi_thread")]
2409    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2410        let cardinality = (3, 4);
2411        let join_type = JoinType::Full;
2412
2413        // a + b > c + 10 AND a + b < c + 100
2414        let config = SessionConfig::new().with_repartition_joins(false);
2417        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2418        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2419        let left_schema = &left_partition[0].schema();
2420        let right_schema = &right_partition[0].schema();
2421        let left_sorted = vec![
2422            [PhysicalSortExpr {
2423                expr: col("la1", left_schema)?,
2424                options: SortOptions::default(),
2425            }]
2426            .into(),
2427            [PhysicalSortExpr {
2428                expr: col("la2", left_schema)?,
2429                options: SortOptions::default(),
2430            }]
2431            .into(),
2432        ];
2433
2434        let right_sorted = [PhysicalSortExpr {
2435            expr: col("ra1", right_schema)?,
2436            options: SortOptions::default(),
2437        }]
2438        .into();
2439
2440        let (left, right) = create_memory_table(
2441            left_partition,
2442            right_partition,
2443            left_sorted,
2444            vec![right_sorted],
2445        )?;
2446
2447        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2448
2449        let intermediate_schema = Schema::new(vec![
2450            Field::new("0", DataType::Int32, true),
2451            Field::new("1", DataType::Int32, true),
2452            Field::new("2", DataType::Int32, true),
2453        ]);
2454        let filter_expr = complicated_filter(&intermediate_schema)?;
2455        let column_indices = vec![
2456            ColumnIndex {
2457                index: 0,
2458                side: JoinSide::Left,
2459            },
2460            ColumnIndex {
2461                index: 4,
2462                side: JoinSide::Left,
2463            },
2464            ColumnIndex {
2465                index: 0,
2466                side: JoinSide::Right,
2467            },
2468        ];
2469        let filter =
2470            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2471
2472        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2473        Ok(())
2474    }
2475
2476    #[rstest]
2477    #[tokio::test(flavor = "multi_thread")]
2478    async fn testing_with_temporal_columns(
2479        #[values(
2480            JoinType::Inner,
2481            JoinType::Left,
2482            JoinType::Right,
2483            JoinType::RightSemi,
2484            JoinType::LeftSemi,
2485            JoinType::LeftAnti,
2486            JoinType::LeftMark,
2487            JoinType::RightAnti,
2488            JoinType::Full
2489        )]
2490        join_type: JoinType,
2491        #[values(
2492            (4, 5),
2493            (12, 17),
2494        )]
2495        cardinality: (i32, i32),
2496        #[values(0, 1, 2)] case_expr: usize,
2497    ) -> Result<()> {
2498        let session_config = SessionConfig::new().with_repartition_joins(false);
2499        let task_ctx = TaskContext::default().with_session_config(session_config);
2500        let task_ctx = Arc::new(task_ctx);
2501        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2502
2503        let left_schema = &left_partition[0].schema();
2504        let right_schema = &right_partition[0].schema();
2505        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2506        let left_sorted = [PhysicalSortExpr {
2507            expr: col("lt1", left_schema)?,
2508            options: SortOptions {
2509                descending: false,
2510                nulls_first: true,
2511            },
2512        }]
2513        .into();
2514        let right_sorted = [PhysicalSortExpr {
2515            expr: col("rt1", right_schema)?,
2516            options: SortOptions {
2517                descending: false,
2518                nulls_first: true,
2519            },
2520        }]
2521        .into();
2522        let (left, right) = create_memory_table(
2523            left_partition,
2524            right_partition,
2525            vec![left_sorted],
2526            vec![right_sorted],
2527        )?;
2528        let intermediate_schema = Schema::new(vec![
2529            Field::new(
2530                "left",
2531                DataType::Timestamp(TimeUnit::Millisecond, None),
2532                false,
2533            ),
2534            Field::new(
2535                "right",
2536                DataType::Timestamp(TimeUnit::Millisecond, None),
2537                false,
2538            ),
2539        ]);
2540        let filter_expr = join_expr_tests_fixture_temporal(
2541            case_expr,
2542            col("left", &intermediate_schema)?,
2543            col("right", &intermediate_schema)?,
2544            &intermediate_schema,
2545        )?;
2546        let column_indices = vec![
2547            ColumnIndex {
2548                index: 3,
2549                side: JoinSide::Left,
2550            },
2551            ColumnIndex {
2552                index: 3,
2553                side: JoinSide::Right,
2554            },
2555        ];
2556        let filter =
2557            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2558        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2559        Ok(())
2560    }
2561
2562    #[rstest]
2563    #[tokio::test(flavor = "multi_thread")]
2564    async fn test_with_interval_columns(
2565        #[values(
2566            JoinType::Inner,
2567            JoinType::Left,
2568            JoinType::Right,
2569            JoinType::RightSemi,
2570            JoinType::LeftSemi,
2571            JoinType::LeftAnti,
2572            JoinType::LeftMark,
2573            JoinType::RightAnti,
2574            JoinType::Full
2575        )]
2576        join_type: JoinType,
2577        #[values(
2578            (4, 5),
2579            (12, 17),
2580        )]
2581        cardinality: (i32, i32),
2582    ) -> Result<()> {
2583        let session_config = SessionConfig::new().with_repartition_joins(false);
2584        let task_ctx = TaskContext::default().with_session_config(session_config);
2585        let task_ctx = Arc::new(task_ctx);
2586        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2587
2588        let left_schema = &left_partition[0].schema();
2589        let right_schema = &right_partition[0].schema();
2590        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2591        let left_sorted = [PhysicalSortExpr {
2592            expr: col("li1", left_schema)?,
2593            options: SortOptions {
2594                descending: false,
2595                nulls_first: true,
2596            },
2597        }]
2598        .into();
2599        let right_sorted = [PhysicalSortExpr {
2600            expr: col("ri1", right_schema)?,
2601            options: SortOptions {
2602                descending: false,
2603                nulls_first: true,
2604            },
2605        }]
2606        .into();
2607        let (left, right) = create_memory_table(
2608            left_partition,
2609            right_partition,
2610            vec![left_sorted],
2611            vec![right_sorted],
2612        )?;
2613        let intermediate_schema = Schema::new(vec![
2614            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2615            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2616        ]);
2617        let filter_expr = join_expr_tests_fixture_temporal(
2618            0,
2619            col("left", &intermediate_schema)?,
2620            col("right", &intermediate_schema)?,
2621            &intermediate_schema,
2622        )?;
2623        let column_indices = vec![
2624            ColumnIndex {
2625                index: 9,
2626                side: JoinSide::Left,
2627            },
2628            ColumnIndex {
2629                index: 9,
2630                side: JoinSide::Right,
2631            },
2632        ];
2633        let filter =
2634            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2635        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2636
2637        Ok(())
2638    }
2639
2640    #[rstest]
2641    #[tokio::test(flavor = "multi_thread")]
2642    async fn testing_ascending_float_pruning(
2643        #[values(
2644            JoinType::Inner,
2645            JoinType::Left,
2646            JoinType::Right,
2647            JoinType::RightSemi,
2648            JoinType::LeftSemi,
2649            JoinType::LeftAnti,
2650            JoinType::LeftMark,
2651            JoinType::RightAnti,
2652            JoinType::Full
2653        )]
2654        join_type: JoinType,
2655        #[values(
2656            (4, 5),
2657            (12, 17),
2658        )]
2659        cardinality: (i32, i32),
2660        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2661    ) -> Result<()> {
2662        let session_config = SessionConfig::new().with_repartition_joins(false);
2663        let task_ctx = TaskContext::default().with_session_config(session_config);
2664        let task_ctx = Arc::new(task_ctx);
2665        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2666
2667        let left_schema = &left_partition[0].schema();
2668        let right_schema = &right_partition[0].schema();
2669        let left_sorted = [PhysicalSortExpr {
2670            expr: col("l_float", left_schema)?,
2671            options: SortOptions::default(),
2672        }]
2673        .into();
2674        let right_sorted = [PhysicalSortExpr {
2675            expr: col("r_float", right_schema)?,
2676            options: SortOptions::default(),
2677        }]
2678        .into();
2679        let (left, right) = create_memory_table(
2680            left_partition,
2681            right_partition,
2682            vec![left_sorted],
2683            vec![right_sorted],
2684        )?;
2685
2686        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2687
2688        let intermediate_schema = Schema::new(vec![
2689            Field::new("left", DataType::Float64, true),
2690            Field::new("right", DataType::Float64, true),
2691        ]);
2692        let filter_expr = join_expr_tests_fixture_f64(
2693            case_expr,
2694            col("left", &intermediate_schema)?,
2695            col("right", &intermediate_schema)?,
2696        );
2697        let column_indices = vec![
2698            ColumnIndex {
2699                index: 10, // l_float
2700                side: JoinSide::Left,
2701            },
2702            ColumnIndex {
2703                index: 10, // r_float
2704                side: JoinSide::Right,
2705            },
2706        ];
2707        let filter =
2708            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2709
2710        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2711        Ok(())
2712    }
2713}