Skip to main content

datafusion_physical_plan/joins/
symmetric_hash_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file implements the symmetric hash join algorithm with range-based
19//! data pruning to join two (potentially infinite) streams.
20//!
//! A [`SymmetricHashJoinExec`] plan takes two child plans (with appropriate
22//! output ordering) and produces the join output according to the given join
23//! type and other options.
24//!
25//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations
26//! for both its children.
27
28use std::any::Any;
29use std::fmt::{self, Debug};
30use std::mem::{size_of, size_of_val};
31use std::sync::Arc;
32use std::task::{Context, Poll};
33use std::vec;
34
35use crate::common::SharedMemoryReservation;
36use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
37use crate::joins::stream_join_utils::{
38    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
39    calculate_filter_expr_intervals, combine_two_batches,
40    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
41    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
42};
43use crate::joins::utils::{
44    BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
45    JoinOnRef, NoopBatchTransformer, StatefulStreamResult, apply_join_filter_to_indices,
46    build_batch_from_indices, build_join_schema, check_join_is_valid, equal_rows_arr,
47    symmetric_join_output_partitioning, update_hash,
48};
49use crate::projection::{
50    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
51    physical_to_column_exprs, update_join_filter, update_join_on,
52};
53use crate::{
54    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
55    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
56    joins::StreamJoinPartitionMode,
57    metrics::{ExecutionPlanMetricsSet, MetricsSet},
58};
59
60use arrow::array::{
61    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
62    UInt64Array,
63};
64use arrow::compute::concat_batches;
65use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
66use arrow::record_batch::RecordBatch;
67use datafusion_common::hash_utils::create_hashes;
68use datafusion_common::utils::bisect;
69use datafusion_common::{
70    HashSet, JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err,
71    plan_err,
72};
73use datafusion_execution::TaskContext;
74use datafusion_execution::memory_pool::MemoryConsumer;
75use datafusion_expr::interval_arithmetic::Interval;
76use datafusion_physical_expr::equivalence::join_equivalence_properties;
77use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
78use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
79use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
80
81use ahash::RandomState;
82use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
83use futures::{Stream, StreamExt, ready};
84use parking_lot::Mutex;
85
/// Scale factor used when deciding whether to shrink the pruning join hash
/// map after rows have been deleted.
/// NOTE(review): the exact shrink policy lives at the call site (not visible
/// here) — confirm semantics there before changing this value.
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
87
88/// A symmetric hash join with range conditions is when both streams are hashed on the
89/// join key and the resulting hash tables are used to join the streams.
90/// The join is considered symmetric because the hash table is built on the join keys from both
91/// streams, and the matching of rows is based on the values of the join keys in both streams.
92/// This type of join is efficient in streaming context as it allows for fast lookups in the hash
93/// table, rather than having to scan through one or both of the streams to find matching rows, also it
94/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions),
95/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
96/// data without any memory issues.
97///
98/// For each input stream, create a hash table.
99///   - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets.
100///   - Test if input is equal to a predefined set of other inputs.
101///   - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
102///   - Try to prune other side (probe) with new [RecordBatch].
103///   - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.),
104///     output the [RecordBatch] when a pruning happens or at the end of the data.
105///
106///
107/// ``` text
108///                        +-------------------------+
109///                        |                         |
110///   left stream ---------|  Left OneSideHashJoiner |---+
111///                        |                         |   |
112///                        +-------------------------+   |
113///                                                      |
114///                                                      |--------- Joined output
115///                                                      |
116///                        +-------------------------+   |
117///                        |                         |   |
118///  right stream ---------| Right OneSideHashJoiner |---+
119///                        |                         |
120///                        +-------------------------+
121///
122/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic
123/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
124///
125///
126///               PROBE SIDE          BUILD SIDE
127///                 BUFFER              BUFFER
128///             +-------------+     +------------+
129///             |             |     |            |    Unjoinable
130///             |             |     |            |    Range
131///             |             |     |            |
132///             |             |  |---------------------------------
133///             |             |  |  |            |
134///             |             |  |  |            |
135///             |             | /   |            |
136///             |             | |   |            |
137///             |             | |   |            |
138///             |             | |   |            |
139///             |             | |   |            |
140///             |             | |   |            |    Joinable
141///             |             |/    |            |    Range
142///             |             ||    |            |
143///             |+-----------+||    |            |
144///             || Record    ||     |            |
145///             || Batch     ||     |            |
146///             |+-----------+||    |            |
147///             +-------------+\    +------------+
148///                             |
149///                             \
150///                              |---------------------------------
151///
152///  This happens when range conditions are provided on sorted columns. E.g.
153///
154///        SELECT * FROM left_table, right_table
155///        ON
156///          left_key = right_key AND
157///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
158///
159/// or
160///       SELECT * FROM left_table, right_table
161///        ON
162///          left_key = right_key AND
163///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
164///
/// In the second scenario, when new data arrives at the probe side, the conditions can be used to
/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort
/// order of the two columns ("left_sorted" and "right_sorted") is ascending (it can differ in other
/// scenarios) and the join condition is "left_sorted > right_sorted - 3" and the latest value on the
/// right input is 1234, then the left side buffer must only keep rows where
/// "left_sorted > right_sorted - 3 > 1234 - 3 > 1231", making 1231 the smallest relevant value in
/// 'left_sorted'; any rows below that (since ascending) can be dropped from the inner buffer.
172/// ```
#[derive(Debug, Clone)]
pub struct SymmetricHashJoinExec {
    /// Left side stream
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right side stream
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Set of equijoin expression pairs (left expression, right expression) used to join on
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Filters applied when finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// Shares the `RandomState` for the hashing algorithm
    random_state: RandomState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Defines the null equality for the join.
    pub(crate) null_equality: NullEquality,
    /// Left side sort expression(s); range pruning is only set up in
    /// `execute` when both sides' sort expressions and a filter are present
    pub(crate) left_sort_exprs: Option<LexOrdering>,
    /// Right side sort expression(s); see `left_sort_exprs`
    pub(crate) right_sort_exprs: Option<LexOrdering>,
    /// Partition Mode
    mode: StreamJoinPartitionMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
202
impl SymmetricHashJoinExec {
    /// Tries to create a new [SymmetricHashJoinExec].
    /// # Error
    /// This function errors when:
    /// - It is not possible to join the left and right sides on keys `on`, or
    /// - It fails to construct `SortedFilterExpr`s, or
    /// - It fails to create the [ExprIntervalGraph].
    #[expect(clippy::too_many_arguments)]
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        null_equality: NullEquality,
        left_sort_exprs: Option<LexOrdering>,
        right_sort_exprs: Option<LexOrdering>,
        mode: StreamJoinPartitionMode,
    ) -> Result<Self> {
        let left_schema = left.schema();
        let right_schema = right.schema();

        // Error out if no "on" constraints are given:
        if on.is_empty() {
            return plan_err!(
                "On constraints in SymmetricHashJoinExec should be non-empty"
            );
        }

        // Check if the join is valid with the given on constraints:
        check_join_is_valid(&left_schema, &right_schema, &on)?;

        // Build the join schema from the left and right schemas:
        let (schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, join_type);

        // Initialize the random state for the join operation:
        // Fixed seeds keep hashing deterministic across runs and instances.
        let random_state = RandomState::with_seeds(0, 0, 0, 0);
        let schema = Arc::new(schema);
        let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?;
        Ok(SymmetricHashJoinExec {
            left,
            right,
            on,
            filter,
            join_type: *join_type,
            random_state,
            metrics: ExecutionPlanMetricsSet::new(),
            column_indices,
            null_equality,
            left_sort_exprs,
            right_sort_exprs,
            mode,
            cache,
        })
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        join_type: JoinType,
        join_on: JoinOnRef,
    ) -> Result<PlanProperties> {
        // Calculate equivalence properties:
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            schema,
            &[false, false],
            // Has alternating probe side
            None,
            join_on,
        )?;

        let output_partitioning =
            symmetric_join_output_partitioning(left, right, &join_type)?;

        // Emission type and boundedness are derived from the children; this
        // operator can be incremental and can run on unbounded inputs.
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children([left, right]),
            boundedness_from_children([left, right]),
        ))
    }

    /// left stream
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// right stream
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// Set of common columns used to join on
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }

    /// Filters applied before join output
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }

    /// How the join is performed
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// Get null_equality
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }

    /// Get partition mode
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
        self.mode
    }

    /// Get left_sort_exprs
    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
        self.left_sort_exprs.as_ref()
    }

    /// Get right_sort_exprs
    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
        self.right_sort_exprs.as_ref()
    }

    /// Check if order information covers every column in the filter expression.
    ///
    /// Returns `Ok(true)` only when both children have an output ordering and
    /// the leading sort expression of each ordering is convertible against the
    /// filter schema for its side.
    pub fn check_if_order_information_available(&self) -> Result<bool> {
        if let Some(filter) = self.filter() {
            let left = self.left();
            if let Some(left_ordering) = left.output_ordering() {
                let right = self.right();
                if let Some(right_ordering) = right.output_ordering() {
                    // Only the first sort expression of each ordering is
                    // checked here.
                    let left_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Left,
                        filter,
                        &left.schema(),
                        &left_ordering[0],
                    )?
                    .is_some();
                    let right_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Right,
                        filter,
                        &right.schema(),
                        &right_ordering[0],
                    )?
                    .is_some();
                    return Ok(left_convertible && right_convertible);
                }
            }
        }
        Ok(false)
    }
}
364
impl DisplayAs for SymmetricHashJoinExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                // Render the optional filter as ", filter=<expr>"; empty when absent.
                let display_filter = self.filter.as_ref().map_or_else(
                    || "".to_string(),
                    |f| format!(", filter={}", f.expression()),
                );
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| format!("({c1}, {c2})"))
                    .collect::<Vec<String>>()
                    .join(", ");
                write!(
                    f,
                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
                    self.mode, self.join_type, on, display_filter
                )
            }
            DisplayFormatType::TreeRender => {
                // SQL-like rendering of the equijoin conditions for tree output.
                let on = self
                    .on
                    .iter()
                    .map(|(c1, c2)| {
                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
                    })
                    .collect::<Vec<String>>()
                    .join(", ");

                writeln!(f, "mode={:?}", self.mode)?;
                // Inner join is the default; only print the join type otherwise.
                if *self.join_type() != JoinType::Inner {
                    writeln!(f, "join_type={:?}", self.join_type)?;
                }
                writeln!(f, "on={on}")
            }
        }
    }
}
404
405impl ExecutionPlan for SymmetricHashJoinExec {
406    fn name(&self) -> &'static str {
407        "SymmetricHashJoinExec"
408    }
409
410    fn as_any(&self) -> &dyn Any {
411        self
412    }
413
414    fn properties(&self) -> &PlanProperties {
415        &self.cache
416    }
417
418    fn required_input_distribution(&self) -> Vec<Distribution> {
419        match self.mode {
420            StreamJoinPartitionMode::Partitioned => {
421                let (left_expr, right_expr) = self
422                    .on
423                    .iter()
424                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
425                    .unzip();
426                vec![
427                    Distribution::HashPartitioned(left_expr),
428                    Distribution::HashPartitioned(right_expr),
429                ]
430            }
431            StreamJoinPartitionMode::SinglePartition => {
432                vec![Distribution::SinglePartition, Distribution::SinglePartition]
433            }
434        }
435    }
436
437    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
438        vec![
439            self.left_sort_exprs
440                .as_ref()
441                .map(|e| OrderingRequirements::from(e.clone())),
442            self.right_sort_exprs
443                .as_ref()
444                .map(|e| OrderingRequirements::from(e.clone())),
445        ]
446    }
447
448    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
449        vec![&self.left, &self.right]
450    }
451
452    fn with_new_children(
453        self: Arc<Self>,
454        children: Vec<Arc<dyn ExecutionPlan>>,
455    ) -> Result<Arc<dyn ExecutionPlan>> {
456        Ok(Arc::new(SymmetricHashJoinExec::try_new(
457            Arc::clone(&children[0]),
458            Arc::clone(&children[1]),
459            self.on.clone(),
460            self.filter.clone(),
461            &self.join_type,
462            self.null_equality,
463            self.left_sort_exprs.clone(),
464            self.right_sort_exprs.clone(),
465            self.mode,
466        )?))
467    }
468
469    fn metrics(&self) -> Option<MetricsSet> {
470        Some(self.metrics.clone_inner())
471    }
472
473    fn statistics(&self) -> Result<Statistics> {
474        // TODO stats: it is not possible in general to know the output size of joins
475        Ok(Statistics::new_unknown(&self.schema()))
476    }
477
478    fn execute(
479        &self,
480        partition: usize,
481        context: Arc<TaskContext>,
482    ) -> Result<SendableRecordBatchStream> {
483        let left_partitions = self.left.output_partitioning().partition_count();
484        let right_partitions = self.right.output_partitioning().partition_count();
485        assert_eq_or_internal_err!(
486            left_partitions,
487            right_partitions,
488            "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
489                 consider using RepartitionExec"
490        );
491        // If `filter_state` and `filter` are both present, then calculate sorted
492        // filter expressions for both sides, and build an expression graph.
493        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
494            self.left_sort_exprs(),
495            self.right_sort_exprs(),
496            &self.filter,
497        ) {
498            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
499                let (left, right, graph) = prepare_sorted_exprs(
500                    filter,
501                    &self.left,
502                    &self.right,
503                    left_sort_exprs,
504                    right_sort_exprs,
505                )?;
506                (Some(left), Some(right), Some(graph))
507            }
508            // If `filter_state` or `filter` is not present, then return None
509            // for all three values:
510            _ => (None, None, None),
511        };
512
513        let (on_left, on_right) = self.on.iter().cloned().unzip();
514
515        let left_side_joiner =
516            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
517        let right_side_joiner =
518            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
519
520        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
521
522        let right_stream = self.right.execute(partition, Arc::clone(&context))?;
523
524        let batch_size = context.session_config().batch_size();
525        let enforce_batch_size_in_joins =
526            context.session_config().enforce_batch_size_in_joins();
527
528        let reservation = Arc::new(Mutex::new(
529            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
530                .register(context.memory_pool()),
531        ));
532        if let Some(g) = graph.as_ref() {
533            reservation.lock().try_grow(g.size())?;
534        }
535
536        if enforce_batch_size_in_joins {
537            Ok(Box::pin(SymmetricHashJoinStream {
538                left_stream,
539                right_stream,
540                schema: self.schema(),
541                filter: self.filter.clone(),
542                join_type: self.join_type,
543                random_state: self.random_state.clone(),
544                left: left_side_joiner,
545                right: right_side_joiner,
546                column_indices: self.column_indices.clone(),
547                metrics: StreamJoinMetrics::new(partition, &self.metrics),
548                graph,
549                left_sorted_filter_expr,
550                right_sorted_filter_expr,
551                null_equality: self.null_equality,
552                state: SHJStreamState::PullRight,
553                reservation,
554                batch_transformer: BatchSplitter::new(batch_size),
555            }))
556        } else {
557            Ok(Box::pin(SymmetricHashJoinStream {
558                left_stream,
559                right_stream,
560                schema: self.schema(),
561                filter: self.filter.clone(),
562                join_type: self.join_type,
563                random_state: self.random_state.clone(),
564                left: left_side_joiner,
565                right: right_side_joiner,
566                column_indices: self.column_indices.clone(),
567                metrics: StreamJoinMetrics::new(partition, &self.metrics),
568                graph,
569                left_sorted_filter_expr,
570                right_sorted_filter_expr,
571                null_equality: self.null_equality,
572                state: SHJStreamState::PullRight,
573                reservation,
574                batch_transformer: NoopBatchTransformer::new(),
575            }))
576        }
577    }
578
579    /// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done,
580    /// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan.
581    /// Otherwise, it returns None.
582    fn try_swapping_with_projection(
583        &self,
584        projection: &ProjectionExec,
585    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
586        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
587        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
588        else {
589            return Ok(None);
590        };
591
592        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
593            self.left().schema().fields().len(),
594            &projection_as_columns,
595        );
596
597        if !join_allows_pushdown(
598            &projection_as_columns,
599            &self.schema(),
600            far_right_left_col_ind,
601            far_left_right_col_ind,
602        ) {
603            return Ok(None);
604        }
605
606        let Some(new_on) = update_join_on(
607            &projection_as_columns[0..=far_right_left_col_ind as _],
608            &projection_as_columns[far_left_right_col_ind as _..],
609            self.on(),
610            self.left().schema().fields().len(),
611        ) else {
612            return Ok(None);
613        };
614
615        let new_filter = if let Some(filter) = self.filter() {
616            match update_join_filter(
617                &projection_as_columns[0..=far_right_left_col_ind as _],
618                &projection_as_columns[far_left_right_col_ind as _..],
619                filter,
620                self.left().schema().fields().len(),
621            ) {
622                Some(updated_filter) => Some(updated_filter),
623                None => return Ok(None),
624            }
625        } else {
626            None
627        };
628
629        let (new_left, new_right) = new_join_children(
630            &projection_as_columns,
631            far_right_left_col_ind,
632            far_left_right_col_ind,
633            self.left(),
634            self.right(),
635        )?;
636
637        SymmetricHashJoinExec::try_new(
638            Arc::new(new_left),
639            Arc::new(new_right),
640            new_on,
641            new_filter,
642            self.join_type(),
643            self.null_equality(),
644            self.right().output_ordering().cloned(),
645            self.left().output_ordering().cloned(),
646            self.partition_mode(),
647        )
648        .map(|e| Some(Arc::new(e) as _))
649    }
650}
651
/// A stream that issues [RecordBatch]es of join output as batches arrive on
/// either input of the join.
struct SymmetricHashJoinStream<T> {
    /// Left input stream
    left_stream: SendableRecordBatchStream,
    /// Right input stream
    right_stream: SendableRecordBatchStream,
    /// Schema of the join output (built in `SymmetricHashJoinExec::try_new`),
    /// not the schema of either input
    schema: Arc<Schema>,
    /// join filter
    filter: Option<JoinFilter>,
    /// type of the join
    join_type: JoinType,
    /// left hash joiner
    left: OneSideHashJoiner,
    /// right hash joiner
    right: OneSideHashJoiner,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Expression graph for range pruning; present only when both sides'
    /// sort expressions and a join filter were provided
    graph: Option<ExprIntervalGraph>,
    /// Left globally sorted filter expr
    left_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Right globally sorted filter expr
    right_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Random state used for hashing initialization
    random_state: RandomState,
    /// Defines the null equality for the join.
    null_equality: NullEquality,
    /// Metrics
    metrics: StreamJoinMetrics,
    /// Memory reservation; the pruning graph's size is reserved up front in
    /// `SymmetricHashJoinExec::execute`
    reservation: SharedMemoryReservation,
    /// State machine for input execution
    state: SHJStreamState,
    /// Transforms the output batch before returning.
    batch_transformer: T,
}
688
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
    for SymmetricHashJoinStream<T>
{
    /// Returns the schema of the batches this stream produces (the join
    /// output schema).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
696
impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Delegate to the state-machine driver (defined elsewhere in this file).
        self.poll_next_impl(cx)
    }
}
707
708/// Determine the pruning length for `buffer`.
709///
710/// This function evaluates the build side filter expression, converts the
711/// result into an array and determines the pruning length by performing a
712/// binary search on the array.
713///
714/// # Arguments
715///
716/// * `buffer`: The record batch to be pruned.
717/// * `build_side_filter_expr`: The filter expression on the build side used
718///   to determine the pruning length.
719///
720/// # Returns
721///
722/// A [Result] object that contains the pruning length. The function will return
723/// an error if
724/// - there is an issue evaluating the build side filter expression;
725/// - there is an issue converting the build side filter expression into an array
726fn determine_prune_length(
727    buffer: &RecordBatch,
728    build_side_filter_expr: &SortedFilterExpr,
729) -> Result<usize> {
730    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
731    let interval = build_side_filter_expr.interval();
732    // Evaluate the build side filter expression and convert it into an array
733    let batch_arr = origin_sorted_expr
734        .expr
735        .evaluate(buffer)?
736        .into_array(buffer.num_rows())?;
737
738    // Get the lower or upper interval based on the sort direction
739    let target = if origin_sorted_expr.options.descending {
740        interval.upper().clone()
741    } else {
742        interval.lower().clone()
743    };
744
745    // Perform binary search on the array to determine the length of the record batch to be pruned
746    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
747}
748
749/// This method determines if the result of the join should be produced in the final step or not.
750///
751/// # Arguments
752///
753/// * `build_side` - Enum indicating the side of the join used as the build side.
754/// * `join_type` - Enum indicating the type of join to be performed.
755///
756/// # Returns
757///
758/// A boolean indicating whether the result of the join should be produced in the final step or not.
759/// The result will be true if the build side is JoinSide::Left and the join type is one of
760/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi.
761/// If the build side is JoinSide::Right, the result will be true if the join type
762/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi.
763fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
764    if build_side == JoinSide::Left {
765        matches!(
766            join_type,
767            JoinType::Left
768                | JoinType::LeftAnti
769                | JoinType::Full
770                | JoinType::LeftSemi
771                | JoinType::LeftMark
772        )
773    } else {
774        matches!(
775            join_type,
776            JoinType::Right
777                | JoinType::RightAnti
778                | JoinType::Full
779                | JoinType::RightSemi
780                | JoinType::RightMark
781        )
782    }
783}
784
785/// Calculate indices by join type.
786///
787/// This method returns a tuple of two arrays: build and probe indices.
788/// The length of both arrays will be the same.
789///
790/// # Arguments
791///
792/// * `build_side`: Join side which defines the build side.
793/// * `prune_length`: Length of the prune data.
794/// * `visited_rows`: Hash set of visited rows of the build side.
795/// * `deleted_offset`: Deleted offset of the build side.
796/// * `join_type`: The type of join to be performed.
797///
798/// # Returns
799///
800/// A tuple of two arrays of primitive types representing the build and probe indices.
801fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
802    build_side: JoinSide,
803    prune_length: usize,
804    visited_rows: &HashSet<usize>,
805    deleted_offset: usize,
806    join_type: JoinType,
807) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
808where
809    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
810{
811    // Store the result in a tuple
812    let result = match (build_side, join_type) {
813        // For a mark join we “mark” each build‐side row with a dummy 0 in the probe‐side index
814        // if it ever matched. For example, if
815        //
816        // prune_length = 5
817        // deleted_offset = 0
818        // visited_rows = {1, 3}
819        //
820        // then we produce:
821        //
822        // build_indices = [0, 1, 2, 3, 4]
823        // probe_indices = [None, Some(0), None, Some(0), None]
824        //
825        // Example: for each build row i in [0..5):
826        //   – We always output its own index i in `build_indices`
827        //   – We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None`
828        (JoinSide::Left, JoinType::LeftMark) => {
829            let build_indices = (0..prune_length)
830                .map(L::Native::from_usize)
831                .collect::<PrimitiveArray<L>>();
832            let probe_indices = (0..prune_length)
833                .map(|idx| {
834                    // For mark join we output a dummy index 0 to indicate the row had a match
835                    visited_rows
836                        .contains(&(idx + deleted_offset))
837                        .then_some(R::Native::from_usize(0).unwrap())
838                })
839                .collect();
840            (build_indices, probe_indices)
841        }
842        (JoinSide::Right, JoinType::RightMark) => {
843            let build_indices = (0..prune_length)
844                .map(L::Native::from_usize)
845                .collect::<PrimitiveArray<L>>();
846            let probe_indices = (0..prune_length)
847                .map(|idx| {
848                    // For mark join we output a dummy index 0 to indicate the row had a match
849                    visited_rows
850                        .contains(&(idx + deleted_offset))
851                        .then_some(R::Native::from_usize(0).unwrap())
852                })
853                .collect();
854            (build_indices, probe_indices)
855        }
856        // In the case of `Left` or `Right` join, or `Full` join, get the anti indices
857        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
858        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
859        | (_, JoinType::Full) => {
860            let build_unmatched_indices =
861                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
862            let mut builder =
863                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
864            builder.append_nulls(build_unmatched_indices.len());
865            let probe_indices = builder.finish();
866            (build_unmatched_indices, probe_indices)
867        }
868        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
869        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
870            let build_unmatched_indices =
871                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
872            let mut builder =
873                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
874            builder.append_nulls(build_unmatched_indices.len());
875            let probe_indices = builder.finish();
876            (build_unmatched_indices, probe_indices)
877        }
878        // The case of other join types is not considered
879        _ => unreachable!(),
880    };
881    Ok(result)
882}
883
884/// This function produces unmatched record results based on the build side,
885/// join type and other parameters.
886///
887/// The method uses first `prune_length` rows from the build side input buffer
888/// to produce results.
889///
890/// # Arguments
891///
892/// * `output_schema` - The schema of the final output record batch.
893/// * `prune_length` - The length of the determined prune length.
894/// * `probe_schema` - The schema of the probe [RecordBatch].
895/// * `join_type` - The type of join to be performed.
896/// * `column_indices` - Indices of columns that are being joined.
897///
898/// # Returns
899///
900/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
901pub(crate) fn build_side_determined_results(
902    build_hash_joiner: &OneSideHashJoiner,
903    output_schema: &SchemaRef,
904    prune_length: usize,
905    probe_schema: SchemaRef,
906    join_type: JoinType,
907    column_indices: &[ColumnIndex],
908) -> Result<Option<RecordBatch>> {
909    // Check if we need to produce a result in the final output:
910    if prune_length > 0
911        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
912    {
913        // Calculate the indices for build and probe sides based on join type and build side:
914        let (build_indices, probe_indices) = calculate_indices_by_join_type(
915            build_hash_joiner.build_side,
916            prune_length,
917            &build_hash_joiner.visited_rows,
918            build_hash_joiner.deleted_offset,
919            join_type,
920        )?;
921
922        // Create an empty probe record batch:
923        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
924        // Build the final result from the indices of build and probe sides:
925        build_batch_from_indices(
926            output_schema.as_ref(),
927            &build_hash_joiner.input_buffer,
928            &empty_probe_batch,
929            &build_indices,
930            &probe_indices,
931            column_indices,
932            build_hash_joiner.build_side,
933        )
934        .map(|batch| (batch.num_rows() > 0).then_some(batch))
935    } else {
936        // If we don't need to produce a result, return None
937        Ok(None)
938    }
939}
940
941/// This method performs a join between the build side input buffer and the probe side batch.
942///
943/// # Arguments
944///
945/// * `build_hash_joiner` - Build side hash joiner
946/// * `probe_hash_joiner` - Probe side hash joiner
947/// * `schema` - A reference to the schema of the output record batch.
948/// * `join_type` - The type of join to be performed.
949/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
950/// * `filter` - An optional filter on the join condition.
951/// * `probe_batch` - The second record batch to be joined.
952/// * `column_indices` - An array of columns to be selected for the result of the join.
953/// * `random_state` - The random state for the join.
954/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
955///
956/// # Returns
957///
958/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
959/// If the join type is one of the above four, the function will return [None].
960#[expect(clippy::too_many_arguments)]
961pub(crate) fn join_with_probe_batch(
962    build_hash_joiner: &mut OneSideHashJoiner,
963    probe_hash_joiner: &mut OneSideHashJoiner,
964    schema: &SchemaRef,
965    join_type: JoinType,
966    filter: Option<&JoinFilter>,
967    probe_batch: &RecordBatch,
968    column_indices: &[ColumnIndex],
969    random_state: &RandomState,
970    null_equality: NullEquality,
971) -> Result<Option<RecordBatch>> {
972    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
973        return Ok(None);
974    }
975    let (build_indices, probe_indices) = lookup_join_hashmap(
976        &build_hash_joiner.hashmap,
977        &build_hash_joiner.input_buffer,
978        probe_batch,
979        &build_hash_joiner.on,
980        &probe_hash_joiner.on,
981        random_state,
982        null_equality,
983        &mut build_hash_joiner.hashes_buffer,
984        Some(build_hash_joiner.deleted_offset),
985    )?;
986
987    let (build_indices, probe_indices) = if let Some(filter) = filter {
988        apply_join_filter_to_indices(
989            &build_hash_joiner.input_buffer,
990            probe_batch,
991            build_indices,
992            probe_indices,
993            filter,
994            build_hash_joiner.build_side,
995            None,
996        )?
997    } else {
998        (build_indices, probe_indices)
999    };
1000
1001    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
1002        record_visited_indices(
1003            &mut build_hash_joiner.visited_rows,
1004            build_hash_joiner.deleted_offset,
1005            &build_indices,
1006        );
1007    }
1008    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
1009        record_visited_indices(
1010            &mut probe_hash_joiner.visited_rows,
1011            probe_hash_joiner.offset,
1012            &probe_indices,
1013        );
1014    }
1015    if matches!(
1016        join_type,
1017        JoinType::LeftAnti
1018            | JoinType::RightAnti
1019            | JoinType::LeftSemi
1020            | JoinType::LeftMark
1021            | JoinType::RightSemi
1022            | JoinType::RightMark
1023    ) {
1024        Ok(None)
1025    } else {
1026        build_batch_from_indices(
1027            schema,
1028            &build_hash_joiner.input_buffer,
1029            probe_batch,
1030            &build_indices,
1031            &probe_indices,
1032            column_indices,
1033            build_hash_joiner.build_side,
1034        )
1035        .map(|batch| (batch.num_rows() > 0).then_some(batch))
1036    }
1037}
1038
1039/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential
1040/// hash collisions.
1041///
1042/// # Arguments
1043///
1044/// * `build_hashmap` - hashmap collected from build side data.
1045/// * `build_batch` - Build side record batch.
1046/// * `probe_batch` - Probe side record batch.
1047/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
1048/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
1049/// * `random_state` - The random state for the join.
1050/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
1051/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
1052/// * `deleted_offset` - deleted offset for build side data.
1053///
1054/// # Returns
1055///
1056/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side,
1057/// matched by join key columns.
1058#[expect(clippy::too_many_arguments)]
1059fn lookup_join_hashmap(
1060    build_hashmap: &PruningJoinHashMap,
1061    build_batch: &RecordBatch,
1062    probe_batch: &RecordBatch,
1063    build_on: &[PhysicalExprRef],
1064    probe_on: &[PhysicalExprRef],
1065    random_state: &RandomState,
1066    null_equality: NullEquality,
1067    hashes_buffer: &mut Vec<u64>,
1068    deleted_offset: Option<usize>,
1069) -> Result<(UInt64Array, UInt32Array)> {
1070    let keys_values = evaluate_expressions_to_arrays(probe_on, probe_batch)?;
1071    let build_join_values = evaluate_expressions_to_arrays(build_on, build_batch)?;
1072
1073    hashes_buffer.clear();
1074    hashes_buffer.resize(probe_batch.num_rows(), 0);
1075    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1076
1077    // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm
1078    // will return build indices for each probe row in a reverse order as such:
1079    // Build Indices: [5, 4, 3]
1080    // Probe Indices: [1, 1, 1]
1081    //
1082    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
1083    // Let's consider probe rows [0,1] as an example:
1084    //
1085    // When the probe iteration sequence is reversed, the following pairings can be derived:
1086    //
1087    // For probe row 1:
1088    //     (5, 1)
1089    //     (4, 1)
1090    //     (3, 1)
1091    //
1092    // For probe row 0:
1093    //     (5, 0)
1094    //     (4, 0)
1095    //     (3, 0)
1096    //
1097    // After reversing both sets of indices, we obtain reversed indices:
1098    //
1099    //     (3,0)
1100    //     (4,0)
1101    //     (5,0)
1102    //     (3,1)
1103    //     (4,1)
1104    //     (5,1)
1105    //
1106    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
1107    let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices(
1108        Box::new(hash_values.iter().enumerate().rev()),
1109        deleted_offset,
1110    );
1111
1112    matched_probe.reverse();
1113    matched_build.reverse();
1114
1115    let build_indices: UInt64Array = matched_build.into();
1116    let probe_indices: UInt32Array = matched_probe.into();
1117
1118    let (build_indices, probe_indices) = equal_rows_arr(
1119        &build_indices,
1120        &probe_indices,
1121        &build_join_values,
1122        &keys_values,
1123        null_equality,
1124    )?;
1125
1126    Ok((build_indices, probe_indices))
1127}
1128
/// Holds the per-side state of a symmetric hash join: the buffered input,
/// its hash map, and the bookkeeping needed for pruning and for emitting
/// unmatched rows at finalization.
pub struct OneSideHashJoiner {
    /// Which side of the join (left or right) this joiner represents
    build_side: JoinSide,
    /// All batches received from this side so far, concatenated (pruned rows
    /// are sliced off the front in `prune_internal_state`)
    pub input_buffer: RecordBatch,
    /// Join-key expressions evaluated against this side's batches
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hash map from join-key hashes to row indices; supports pruning of
    /// no-longer-joinable rows
    pub(crate) hashmap: PruningJoinHashMap,
    /// Scratch buffer reused across batches when computing row hashes
    pub(crate) hashes_buffer: Vec<u64>,
    /// Row indices (absolute, i.e. including `deleted_offset`) that have
    /// matched at least once; consulted when producing final results
    pub(crate) visited_rows: HashSet<usize>,
    /// Row offset passed to `update_hash` when inserting new rows
    /// (presumably the absolute index of the next inserted row — maintained
    /// outside this view)
    pub(crate) offset: usize,
    /// Number of rows pruned (deleted) from the front of `input_buffer` so far
    pub(crate) deleted_offset: usize,
}
1147
1148impl OneSideHashJoiner {
1149    pub fn size(&self) -> usize {
1150        let mut size = 0;
1151        size += size_of_val(self);
1152        size += size_of_val(&self.build_side);
1153        size += self.input_buffer.get_array_memory_size();
1154        size += size_of_val(&self.on);
1155        size += self.hashmap.size();
1156        size += self.hashes_buffer.capacity() * size_of::<u64>();
1157        size += self.visited_rows.capacity() * size_of::<usize>();
1158        size += size_of_val(&self.offset);
1159        size += size_of_val(&self.deleted_offset);
1160        size
1161    }
1162    pub fn new(
1163        build_side: JoinSide,
1164        on: Vec<PhysicalExprRef>,
1165        schema: SchemaRef,
1166    ) -> Self {
1167        Self {
1168            build_side,
1169            input_buffer: RecordBatch::new_empty(schema),
1170            on,
1171            hashmap: PruningJoinHashMap::with_capacity(0),
1172            hashes_buffer: vec![],
1173            visited_rows: HashSet::new(),
1174            offset: 0,
1175            deleted_offset: 0,
1176        }
1177    }
1178
1179    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
1180    ///
1181    /// # Arguments
1182    ///
1183    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
1184    /// * `random_state` - The random state used to hash values
1185    ///
1186    /// # Returns
1187    ///
1188    /// Returns a [Result] encapsulating any intermediate errors.
1189    pub(crate) fn update_internal_state(
1190        &mut self,
1191        batch: &RecordBatch,
1192        random_state: &RandomState,
1193    ) -> Result<()> {
1194        // Merge the incoming batch with the existing input buffer:
1195        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
1196        // Resize the hashes buffer to the number of rows in the incoming batch:
1197        self.hashes_buffer.resize(batch.num_rows(), 0);
1198        // Get allocation_info before adding the item
1199        // Update the hashmap with the join key values and hashes of the incoming batch:
1200        update_hash(
1201            &self.on,
1202            batch,
1203            &mut self.hashmap,
1204            self.offset,
1205            random_state,
1206            &mut self.hashes_buffer,
1207            self.deleted_offset,
1208            false,
1209        )?;
1210        Ok(())
1211    }
1212
1213    /// Calculate prune length.
1214    ///
1215    /// # Arguments
1216    ///
1217    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression..
1218    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
1219    /// * `graph` - A mutable reference to the physical expression graph.
1220    ///
1221    /// # Returns
1222    ///
1223    /// A Result object that contains the pruning length.
1224    pub(crate) fn calculate_prune_length_with_probe_batch(
1225        &mut self,
1226        build_side_sorted_filter_expr: &mut SortedFilterExpr,
1227        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
1228        graph: &mut ExprIntervalGraph,
1229    ) -> Result<usize> {
1230        // Return early if the input buffer is empty:
1231        if self.input_buffer.num_rows() == 0 {
1232            return Ok(0);
1233        }
1234        // Process the build and probe side sorted filter expressions if both are present:
1235        // Collect the sorted filter expressions into a vector of (node_index, interval) tuples:
1236        let mut filter_intervals = vec![];
1237        for expr in [
1238            &build_side_sorted_filter_expr,
1239            &probe_side_sorted_filter_expr,
1240        ] {
1241            filter_intervals.push((expr.node_index(), expr.interval().clone()))
1242        }
1243        // Update the physical expression graph using the join filter intervals:
1244        graph.update_ranges(&mut filter_intervals, Interval::TRUE)?;
1245        // Extract the new join filter interval for the build side:
1246        let calculated_build_side_interval = filter_intervals.remove(0).1;
1247        // If the intervals have not changed, return early without pruning:
1248        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
1249            return Ok(0);
1250        }
1251        // Update the build side interval and determine the pruning length:
1252        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);
1253
1254        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
1255    }
1256
1257    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
1258        // Prune the hash values:
1259        self.hashmap.prune_hash_values(
1260            prune_length,
1261            self.deleted_offset as u64,
1262            HASHMAP_SHRINK_SCALE_FACTOR,
1263        );
1264        // Remove pruned rows from the visited rows set:
1265        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
1266            self.visited_rows.remove(&row);
1267        }
1268        // Update the input buffer after pruning:
1269        self.input_buffer = self
1270            .input_buffer
1271            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
1272        // Increment the deleted offset:
1273        self.deleted_offset += prune_length;
1274        Ok(())
1275    }
1276}
1277
1278/// `SymmetricHashJoinStream` manages incremental join operations between two
1279/// streams. Unlike traditional join approaches that need to scan one side of
1280/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
1281/// more dynamic join operations by working with streams as they emit data. This
1282/// approach allows for more efficient processing, particularly in scenarios
1283/// where waiting for complete data materialization is not feasible or optimal.
1284/// The trait provides a framework for handling various states of such a join
1285/// process, ensuring that join logic is efficiently executed as data becomes
1286/// available from either stream.
1287///
1288/// This implementation performs eager joins of data from two different asynchronous
1289/// streams, typically referred to as left and right streams. The implementation
1290/// provides a comprehensive set of methods to control and execute the join
1291/// process, leveraging the states defined in `SHJStreamState`. Methods are
1292/// primarily focused on asynchronously fetching data batches from each stream,
1293/// processing them, and managing transitions between various states of the join.
1294///
/// This implementation uses a state machine approach to navigate different
1296/// stages of the join operation, handling data from both streams and determining
1297/// when the join completes.
1298///
1299/// State Transitions:
1300/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
1301///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
1302///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
1303///       processing the batch.
1304///     - On error (`Some(Err(e))`), the error is returned, and the state remains
1305///       unchanged.
1306///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
1307///       to proceed with the join process.
1308/// - From `PullRight` to `PullLeft` or `RightExhausted`:
1309///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
1310///     - If a batch is available, state changes to `PullLeft` for processing.
1311///     - On error, the error is returned without changing the state.
1312///     - If right stream is exhausted (`None`), state transitions to `RightExhausted`,
1313///       with a `Continue` result.
1314/// - Handling `RightExhausted` and `LeftExhausted`:
1315///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
1316///     when streams are exhausted:
1317///     - They attempt to continue processing with the other stream.
1318///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
1319/// - Transition to `BothExhausted { final_result: true }`:
1320///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
1321///     exhausted, indicating completion of processing and availability of final results.
1322impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
1323    /// Implements the main polling logic for the join stream.
1324    ///
1325    /// This method continuously checks the state of the join stream and
1326    /// acts accordingly by delegating the handling to appropriate sub-methods
1327    /// depending on the current state.
1328    ///
1329    /// # Arguments
1330    ///
1331    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
1332    ///
1333    /// # Returns
1334    ///
1335    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
    fn poll_next_impl(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            // Emit any batch the transformer still holds before advancing the
            // join state machine.
            match self.batch_transformer.next() {
                None => {
                    // No buffered output: take one step according to the
                    // current state (see the state-transition table above).
                    let result = match self.state() {
                        SHJStreamState::PullRight => {
                            ready!(self.fetch_next_from_right_stream(cx))
                        }
                        SHJStreamState::PullLeft => {
                            ready!(self.fetch_next_from_left_stream(cx))
                        }
                        SHJStreamState::RightExhausted => {
                            ready!(self.handle_right_stream_end(cx))
                        }
                        SHJStreamState::LeftExhausted => {
                            ready!(self.handle_left_stream_end(cx))
                        }
                        SHJStreamState::BothExhausted {
                            final_result: false,
                        } => self.prepare_for_final_results_after_exhaustion(),
                        // Final results already produced; the stream is done.
                        SHJStreamState::BothExhausted { final_result: true } => {
                            return Poll::Ready(None);
                        }
                    };

                    match result? {
                        // `Ready(None)`: no more output will ever come.
                        StatefulStreamResult::Ready(None) => {
                            return Poll::Ready(None);
                        }
                        // Hand a produced batch to the transformer; it is
                        // emitted (possibly in pieces) on later iterations.
                        StatefulStreamResult::Ready(Some(batch)) => {
                            self.batch_transformer.set_batch(batch);
                        }
                        // `Continue`: this step produced nothing; loop again.
                        _ => {}
                    }
                }
                Some((batch, _)) => {
                    // Record the poll in the baseline metrics and return the
                    // next output batch.
                    return self
                        .metrics
                        .baseline_metrics
                        .record_poll(Poll::Ready(Some(Ok(batch))));
                }
            }
        }
    }
1383    /// Asynchronously pulls the next batch from the right stream.
1384    ///
1385    /// This default implementation checks for the next value in the right stream.
1386    /// If a batch is found, the state is switched to `PullLeft`, and the batch handling
1387    /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`.
1388    ///
1389    /// # Returns
1390    ///
1391    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1392    fn fetch_next_from_right_stream(
1393        &mut self,
1394        cx: &mut Context<'_>,
1395    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1396        match ready!(self.right_stream().poll_next_unpin(cx)) {
1397            Some(Ok(batch)) => {
1398                if batch.num_rows() == 0 {
1399                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1400                }
1401                self.set_state(SHJStreamState::PullLeft);
1402                Poll::Ready(self.process_batch_from_right(&batch))
1403            }
1404            Some(Err(e)) => Poll::Ready(Err(e)),
1405            None => {
1406                self.set_state(SHJStreamState::RightExhausted);
1407                Poll::Ready(Ok(StatefulStreamResult::Continue))
1408            }
1409        }
1410    }
1411
1412    /// Asynchronously pulls the next batch from the left stream.
1413    ///
1414    /// This default implementation checks for the next value in the left stream.
1415    /// If a batch is found, the state is switched to `PullRight`, and the batch handling
1416    /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`.
1417    ///
1418    /// # Returns
1419    ///
1420    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1421    fn fetch_next_from_left_stream(
1422        &mut self,
1423        cx: &mut Context<'_>,
1424    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1425        match ready!(self.left_stream().poll_next_unpin(cx)) {
1426            Some(Ok(batch)) => {
1427                if batch.num_rows() == 0 {
1428                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1429                }
1430                self.set_state(SHJStreamState::PullRight);
1431                Poll::Ready(self.process_batch_from_left(&batch))
1432            }
1433            Some(Err(e)) => Poll::Ready(Err(e)),
1434            None => {
1435                self.set_state(SHJStreamState::LeftExhausted);
1436                Poll::Ready(Ok(StatefulStreamResult::Continue))
1437            }
1438        }
1439    }
1440
1441    /// Asynchronously handles the scenario when the right stream is exhausted.
1442    ///
1443    /// In this default implementation, when the right stream is exhausted, it attempts
1444    /// to pull from the left stream. If a batch is found in the left stream, it delegates
1445    /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set
1446    /// to indicate both streams are exhausted without final results yet.
1447    ///
1448    /// # Returns
1449    ///
1450    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1451    fn handle_right_stream_end(
1452        &mut self,
1453        cx: &mut Context<'_>,
1454    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1455        match ready!(self.left_stream().poll_next_unpin(cx)) {
1456            Some(Ok(batch)) => {
1457                if batch.num_rows() == 0 {
1458                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1459                }
1460                Poll::Ready(self.process_batch_after_right_end(&batch))
1461            }
1462            Some(Err(e)) => Poll::Ready(Err(e)),
1463            None => {
1464                self.set_state(SHJStreamState::BothExhausted {
1465                    final_result: false,
1466                });
1467                Poll::Ready(Ok(StatefulStreamResult::Continue))
1468            }
1469        }
1470    }
1471
1472    /// Asynchronously handles the scenario when the left stream is exhausted.
1473    ///
1474    /// When the left stream is exhausted, this default
1475    /// implementation tries to pull from the right stream and delegates the batch
1476    /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state
1477    /// is updated to indicate so.
1478    ///
1479    /// # Returns
1480    ///
1481    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1482    fn handle_left_stream_end(
1483        &mut self,
1484        cx: &mut Context<'_>,
1485    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1486        match ready!(self.right_stream().poll_next_unpin(cx)) {
1487            Some(Ok(batch)) => {
1488                if batch.num_rows() == 0 {
1489                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1490                }
1491                Poll::Ready(self.process_batch_after_left_end(&batch))
1492            }
1493            Some(Err(e)) => Poll::Ready(Err(e)),
1494            None => {
1495                self.set_state(SHJStreamState::BothExhausted {
1496                    final_result: false,
1497                });
1498                Poll::Ready(Ok(StatefulStreamResult::Continue))
1499            }
1500        }
1501    }
1502
    /// Handles the state when both streams are exhausted and final results are yet to be produced.
    ///
    /// This default implementation switches the state to indicate both streams are
    /// exhausted with final results and then invokes the handling for this specific
    /// scenario via `process_batches_before_finalization`.
    ///
    /// # Returns
    ///
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
    fn prepare_for_final_results_after_exhaustion(
        &mut self,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        // Flip to the terminal state first so finalization runs only once.
        self.set_state(SHJStreamState::BothExhausted { final_result: true });
        self.process_batches_before_finalization()
    }
1518
1519    fn process_batch_from_right(
1520        &mut self,
1521        batch: &RecordBatch,
1522    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1523        self.perform_join_for_given_side(batch, JoinSide::Right)
1524            .map(|maybe_batch| {
1525                if maybe_batch.is_some() {
1526                    StatefulStreamResult::Ready(maybe_batch)
1527                } else {
1528                    StatefulStreamResult::Continue
1529                }
1530            })
1531    }
1532
1533    fn process_batch_from_left(
1534        &mut self,
1535        batch: &RecordBatch,
1536    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1537        self.perform_join_for_given_side(batch, JoinSide::Left)
1538            .map(|maybe_batch| {
1539                if maybe_batch.is_some() {
1540                    StatefulStreamResult::Ready(maybe_batch)
1541                } else {
1542                    StatefulStreamResult::Continue
1543                }
1544            })
1545    }
1546
    /// Processes a right-side batch that arrived after the left stream ended.
    /// Delegates to the regular right-batch join path.
    fn process_batch_after_left_end(
        &mut self,
        right_batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_right(right_batch)
    }
1553
    /// Processes a left-side batch that arrived after the right stream ended.
    /// Delegates to the regular left-batch join path.
    fn process_batch_after_right_end(
        &mut self,
        left_batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_left(left_batch)
    }
1560
1561    fn process_batches_before_finalization(
1562        &mut self,
1563    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1564        // Get the left side results:
1565        let left_result = build_side_determined_results(
1566            &self.left,
1567            &self.schema,
1568            self.left.input_buffer.num_rows(),
1569            self.right.input_buffer.schema(),
1570            self.join_type,
1571            &self.column_indices,
1572        )?;
1573        // Get the right side results:
1574        let right_result = build_side_determined_results(
1575            &self.right,
1576            &self.schema,
1577            self.right.input_buffer.num_rows(),
1578            self.left.input_buffer.schema(),
1579            self.join_type,
1580            &self.column_indices,
1581        )?;
1582
1583        // Combine the left and right results:
1584        let result = combine_two_batches(&self.schema, left_result, right_result)?;
1585
1586        // Return the result:
1587        if result.is_some() {
1588            return Ok(StatefulStreamResult::Ready(result));
1589        }
1590        Ok(StatefulStreamResult::Continue)
1591    }
1592
    /// Mutable access to the right input stream.
    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.right_stream
    }
1596
    /// Mutable access to the left input stream.
    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.left_stream
    }
1600
    /// Replaces the stream's state-machine state.
    fn set_state(&mut self, state: SHJStreamState) {
        self.state = state;
    }
1604
    /// Returns a clone of the current state-machine state.
    fn state(&mut self) -> SHJStreamState {
        self.state.clone()
    }
1608
1609    fn size(&self) -> usize {
1610        let mut size = 0;
1611        size += size_of_val(&self.schema);
1612        size += size_of_val(&self.filter);
1613        size += size_of_val(&self.join_type);
1614        size += self.left.size();
1615        size += self.right.size();
1616        size += size_of_val(&self.column_indices);
1617        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
1618        size += size_of_val(&self.left_sorted_filter_expr);
1619        size += size_of_val(&self.right_sorted_filter_expr);
1620        size += size_of_val(&self.random_state);
1621        size += size_of_val(&self.null_equality);
1622        size += size_of_val(&self.metrics);
1623        size
1624    }
1625
    /// Performs a join operation for the specified `probe_side` (either left or right).
    /// This function:
    /// 1. Determines which side is the probe and which is the build side.
    /// 2. Updates metrics based on the batch that was polled.
    /// 3. Executes the join with the given `probe_batch`.
    /// 4. Optionally computes anti-join results if all conditions are met.
    /// 5. Combines the results and returns a combined batch or `None` if no batch was produced.
    fn perform_join_for_given_side(
        &mut self,
        probe_batch: &RecordBatch,
        probe_side: JoinSide,
    ) -> Result<Option<RecordBatch>> {
        // Select the probe/build joiners, sorted filter expressions and
        // metrics according to which side supplied `probe_batch`.
        let (
            probe_hash_joiner,
            build_hash_joiner,
            probe_side_sorted_filter_expr,
            build_side_sorted_filter_expr,
            probe_side_metrics,
        ) = if probe_side.eq(&JoinSide::Left) {
            (
                &mut self.left,
                &mut self.right,
                &mut self.left_sorted_filter_expr,
                &mut self.right_sorted_filter_expr,
                &mut self.metrics.left,
            )
        } else {
            (
                &mut self.right,
                &mut self.left,
                &mut self.right_sorted_filter_expr,
                &mut self.left_sorted_filter_expr,
                &mut self.metrics.right,
            )
        };
        // Update the metrics for the stream that was polled:
        probe_side_metrics.input_batches.add(1);
        probe_side_metrics.input_rows.add(probe_batch.num_rows());
        // Update the internal state of the hash joiner for the build side:
        probe_hash_joiner.update_internal_state(probe_batch, &self.random_state)?;
        // Join the two sides:
        let equal_result = join_with_probe_batch(
            build_hash_joiner,
            probe_hash_joiner,
            &self.schema,
            self.join_type,
            self.filter.as_ref(),
            probe_batch,
            &self.column_indices,
            &self.random_state,
            self.null_equality,
        )?;
        // Increment the offset for the probe hash joiner:
        probe_hash_joiner.offset += probe_batch.num_rows();

        // Pruning runs only when both sorted filter expressions and the
        // expression interval graph are available:
        let anti_result = if let (
            Some(build_side_sorted_filter_expr),
            Some(probe_side_sorted_filter_expr),
            Some(graph),
        ) = (
            build_side_sorted_filter_expr.as_mut(),
            probe_side_sorted_filter_expr.as_mut(),
            self.graph.as_mut(),
        ) {
            // Calculate filter intervals:
            calculate_filter_expr_intervals(
                &build_hash_joiner.input_buffer,
                build_side_sorted_filter_expr,
                probe_batch,
                probe_side_sorted_filter_expr,
            )?;
            // Number of buffered build-side rows that can be safely dropped:
            let prune_length = build_hash_joiner
                .calculate_prune_length_with_probe_batch(
                    build_side_sorted_filter_expr,
                    probe_side_sorted_filter_expr,
                    graph,
                )?;
            // Emit any join-type-dependent results for the rows being pruned
            // before discarding them:
            let result = build_side_determined_results(
                build_hash_joiner,
                &self.schema,
                prune_length,
                probe_batch.schema(),
                self.join_type,
                &self.column_indices,
            )?;
            build_hash_joiner.prune_internal_state(prune_length)?;
            result
        } else {
            None
        };

        // Combine results:
        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
        // Refresh memory accounting after the internal buffers changed size:
        let capacity = self.size();
        self.metrics.stream_memory_usage.set(capacity);
        self.reservation.lock().try_resize(capacity)?;
        Ok(result)
    }
1724}
1725
/// Represents the various states of a symmetric hash join stream operation.
///
/// This enum is used to track the current state of streaming during a join
/// operation. It provides indicators as to which side of the join needs to be
/// pulled next or if one (or both) sides have been exhausted. This allows
/// for efficient management of resources and optimal performance during the
/// join process.
#[derive(Clone, Debug)]
pub enum SHJStreamState {
    /// Indicates that the next step should pull from the right side of the join.
    PullRight,

    /// Indicates that the next step should pull from the left side of the join.
    PullLeft,

    /// State representing that the right side of the join has been fully processed.
    RightExhausted,

    /// State representing that the left side of the join has been fully processed.
    LeftExhausted,

    /// Represents a state where both sides of the join are exhausted.
    ///
    /// The `final_result` field indicates whether the join operation has
    /// produced a final result or not.
    BothExhausted { final_result: bool },
}
1753
1754#[cfg(test)]
1755mod tests {
1756    use std::collections::HashMap;
1757    use std::sync::{LazyLock, Mutex};
1758
1759    use super::*;
1760    use crate::joins::test_utils::{
1761        build_sides_record_batches, compare_batches, complicated_filter,
1762        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1763        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1764        partitioned_sym_join_with_filter, split_record_batches,
1765    };
1766
1767    use arrow::compute::SortOptions;
1768    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1769    use datafusion_common::ScalarValue;
1770    use datafusion_execution::config::SessionConfig;
1771    use datafusion_expr::Operator;
1772    use datafusion_physical_expr::expressions::{Column, binary, col, lit};
1773    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
1774
1775    use rstest::*;
1776
    // Number of rows per generated test table side.
    const TABLE_SIZE: i32 = 30;

    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)

    // Process-wide cache so identical (cardinality, batch_size) tables are
    // built only once across the many parameterized test cases.
    static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));
1785
1786    fn get_or_create_table(
1787        cardinality: (i32, i32),
1788        batch_size: usize,
1789    ) -> Result<TableValue> {
1790        {
1791            let cache = TABLE_CACHE.lock().unwrap();
1792            if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1793                return Ok(table.clone());
1794            }
1795        }
1796
1797        // If not, create the table
1798        let (left_batch, right_batch) =
1799            build_sides_record_batches(TABLE_SIZE, cardinality)?;
1800
1801        let (left_partition, right_partition) = (
1802            split_record_batches(&left_batch, batch_size)?,
1803            split_record_batches(&right_batch, batch_size)?,
1804        );
1805
1806        // Lock the cache again and store the table
1807        let mut cache = TABLE_CACHE.lock().unwrap();
1808
1809        // Store the table in the cache
1810        cache.insert(
1811            (cardinality.0, cardinality.1, batch_size),
1812            (left_partition.clone(), right_partition.clone()),
1813        );
1814
1815        Ok((left_partition, right_partition))
1816    }
1817
    /// Runs the same join with both the symmetric hash join and the regular
    /// hash join implementations and asserts that they produce identical
    /// result batches.
    pub async fn experiment(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        filter: Option<JoinFilter>,
        join_type: JoinType,
        on: JoinOn,
        task_ctx: Arc<TaskContext>,
    ) -> Result<()> {
        // Symmetric hash join (the implementation under test):
        let first_batches = partitioned_sym_join_with_filter(
            Arc::clone(&left),
            Arc::clone(&right),
            on.clone(),
            filter.clone(),
            &join_type,
            NullEquality::NullEqualsNothing,
            Arc::clone(&task_ctx),
        )
        .await?;
        // Regular hash join (the reference implementation):
        let second_batches = partitioned_hash_join_with_filter(
            left,
            right,
            on,
            filter,
            &join_type,
            NullEquality::NullEqualsNothing,
            task_ctx,
        )
        .await?;
        compare_batches(&first_batches, &second_batches);
        Ok(())
    }
1849
    // Exercises a compound filter over a computed sort expression (la1 + la2)
    // for every join type and two cardinality settings.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn complex_join_all_one_ascending_numeric(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(
        (4, 5),
        (12, 17),
        )]
        cardinality: (i32, i32),
    ) -> Result<()> {
        // a + b > c + 10 AND a + b < c + 100
        let task_ctx = Arc::new(TaskContext::default());

        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();

        // The left side is sorted on the expression la1 + la2, not a bare column:
        let left_sorted = [PhysicalSortExpr {
            expr: binary(
                col("la1", left_schema)?,
                Operator::Plus,
                col("la2", left_schema)?,
                left_schema,
            )?,
            options: SortOptions::default(),
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("ra1", right_schema)?,
            options: SortOptions::default(),
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        // Equi-join key on the left is also an expression: lc1 + 1.
        let on = vec![(
            binary(
                col("lc1", left_schema)?,
                Operator::Plus,
                lit(ScalarValue::Int32(Some(1))),
                left_schema,
            )?,
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
        )];

        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&intermediate_schema)?;
        let column_indices = vec![
            ColumnIndex {
                index: left_schema.index_of("la1")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: left_schema.index_of("la2")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: right_schema.index_of("ra1")?,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
1938
    // Exercises single-column ascending sort keys with each i32 filter
    // fixture, for every join type.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn join_all_one_ascending_numeric(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
    ) -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();

        let left_sorted = [PhysicalSortExpr {
            expr: col("la1", left_schema)?,
            options: SortOptions::default(),
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("ra1", right_schema)?,
            options: SortOptions::default(),
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        // Intermediate schema describes the filter's two input columns:
        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2007
    // Verifies correctness when the inputs declare no sort order, so no
    // range-based pruning can occur.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn join_without_sort_information(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
    ) -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // No sort expressions are supplied to either side:
        let (left, right) =
            create_memory_table(left_partition, right_partition, vec![], vec![])?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        let column_indices = vec![
            ColumnIndex {
                index: 5,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 5,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2061
    // Verifies the pure equi-join path (no filter, no sort information) for
    // every join type.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn join_without_filter(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
    ) -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        let (left, right) =
            create_memory_table(left_partition, right_partition, vec![], vec![])?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
        experiment(left, right, None, join_type, on, task_ctx).await?;
        Ok(())
    }
2090
    // Exercises descending sort keys (nulls first) with each i32 filter
    // fixture, for every join type.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn join_all_one_descending_numeric_particular(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
    ) -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // Both sides are sorted descending with nulls first:
        let left_sorted = [PhysicalSortExpr {
            expr: col("la1_des", left_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("ra1_des", right_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        let column_indices = vec![
            ColumnIndex {
                index: 5,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 5,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2163
    // Full join over ascending columns with nulls ordered first, with
    // repartitioning disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn build_null_columns_first() -> Result<()> {
        let join_type = JoinType::Full;
        let case_expr = 1;
        let session_config = SessionConfig::new().with_repartition_joins(false);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        let left_sorted = [PhysicalSortExpr {
            expr: col("l_asc_null_first", left_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("r_asc_null_first", right_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        let column_indices = vec![
            ColumnIndex {
                index: 6,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 6,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2223
    // Full join over ascending columns with nulls ordered last, with
    // repartitioning disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn build_null_columns_last() -> Result<()> {
        let join_type = JoinType::Full;
        let case_expr = 1;
        let session_config = SessionConfig::new().with_repartition_joins(false);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        let left_sorted = [PhysicalSortExpr {
            expr: col("l_asc_null_last", left_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("r_asc_null_last", right_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        let column_indices = vec![
            ColumnIndex {
                index: 7,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 7,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2285
    // Full join over descending columns with nulls ordered first, with
    // repartitioning disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn build_null_columns_first_descending() -> Result<()> {
        let join_type = JoinType::Full;
        let cardinality = (10, 11);
        let case_expr = 1;
        let session_config = SessionConfig::new().with_repartition_joins(false);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        let left_sorted = [PhysicalSortExpr {
            expr: col("l_desc_null_first", left_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("r_desc_null_first", right_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        let column_indices = vec![
            ColumnIndex {
                index: 8,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 8,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2348
2349    #[tokio::test(flavor = "multi_thread")]
2350    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2351        let cardinality = (3, 4);
2352        let join_type = JoinType::Full;
2353
2354        // a + b > c + 10 AND a + b < c + 100
2355        let session_config = SessionConfig::new().with_repartition_joins(false);
2356        let task_ctx = TaskContext::default().with_session_config(session_config);
2357        let task_ctx = Arc::new(task_ctx);
2358        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2359
2360        let left_schema = &left_partition[0].schema();
2361        let right_schema = &right_partition[0].schema();
2362        let left_sorted = [PhysicalSortExpr {
2363            expr: col("la1", left_schema)?,
2364            options: SortOptions::default(),
2365        }]
2366        .into();
2367        let right_sorted = [PhysicalSortExpr {
2368            expr: col("ra1", right_schema)?,
2369            options: SortOptions::default(),
2370        }]
2371        .into();
2372        let (left, right) = create_memory_table(
2373            left_partition,
2374            right_partition,
2375            vec![left_sorted],
2376            vec![right_sorted],
2377        )?;
2378
2379        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2380
2381        let intermediate_schema = Schema::new(vec![
2382            Field::new("0", DataType::Int32, true),
2383            Field::new("1", DataType::Int32, true),
2384            Field::new("2", DataType::Int32, true),
2385        ]);
2386        let filter_expr = complicated_filter(&intermediate_schema)?;
2387        let column_indices = vec![
2388            ColumnIndex {
2389                index: 0,
2390                side: JoinSide::Left,
2391            },
2392            ColumnIndex {
2393                index: 4,
2394                side: JoinSide::Left,
2395            },
2396            ColumnIndex {
2397                index: 0,
2398                side: JoinSide::Right,
2399            },
2400        ];
2401        let filter =
2402            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2403
2404        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2405        Ok(())
2406    }
2407
2408    #[tokio::test(flavor = "multi_thread")]
2409    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2410        let cardinality = (3, 4);
2411        let join_type = JoinType::Full;
2412
2413        // a + b > c + 10 AND a + b < c + 100
2414        let config = SessionConfig::new().with_repartition_joins(false);
2415        // let session_ctx = SessionContext::with_config(config);
2416        // let task_ctx = session_ctx.task_ctx();
2417        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2418        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2419        let left_schema = &left_partition[0].schema();
2420        let right_schema = &right_partition[0].schema();
2421        let left_sorted = vec![
2422            [PhysicalSortExpr {
2423                expr: col("la1", left_schema)?,
2424                options: SortOptions::default(),
2425            }]
2426            .into(),
2427            [PhysicalSortExpr {
2428                expr: col("la2", left_schema)?,
2429                options: SortOptions::default(),
2430            }]
2431            .into(),
2432        ];
2433
2434        let right_sorted = [PhysicalSortExpr {
2435            expr: col("ra1", right_schema)?,
2436            options: SortOptions::default(),
2437        }]
2438        .into();
2439
2440        let (left, right) = create_memory_table(
2441            left_partition,
2442            right_partition,
2443            left_sorted,
2444            vec![right_sorted],
2445        )?;
2446
2447        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2448
2449        let intermediate_schema = Schema::new(vec![
2450            Field::new("0", DataType::Int32, true),
2451            Field::new("1", DataType::Int32, true),
2452            Field::new("2", DataType::Int32, true),
2453        ]);
2454        let filter_expr = complicated_filter(&intermediate_schema)?;
2455        let column_indices = vec![
2456            ColumnIndex {
2457                index: 0,
2458                side: JoinSide::Left,
2459            },
2460            ColumnIndex {
2461                index: 4,
2462                side: JoinSide::Left,
2463            },
2464            ColumnIndex {
2465                index: 0,
2466                side: JoinSide::Right,
2467            },
2468        ];
2469        let filter =
2470            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2471
2472        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2473        Ok(())
2474    }
2475
2476    #[rstest]
2477    #[tokio::test(flavor = "multi_thread")]
2478    async fn testing_with_temporal_columns(
2479        #[values(
2480            JoinType::Inner,
2481            JoinType::Left,
2482            JoinType::Right,
2483            JoinType::RightSemi,
2484            JoinType::LeftSemi,
2485            JoinType::LeftAnti,
2486            JoinType::LeftMark,
2487            JoinType::RightAnti,
2488            JoinType::RightMark,
2489            JoinType::Full
2490        )]
2491        join_type: JoinType,
2492        #[values(
2493            (4, 5),
2494            (12, 17),
2495        )]
2496        cardinality: (i32, i32),
2497        #[values(0, 1, 2)] case_expr: usize,
2498    ) -> Result<()> {
2499        let session_config = SessionConfig::new().with_repartition_joins(false);
2500        let task_ctx = TaskContext::default().with_session_config(session_config);
2501        let task_ctx = Arc::new(task_ctx);
2502        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2503
2504        let left_schema = &left_partition[0].schema();
2505        let right_schema = &right_partition[0].schema();
2506        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2507        let left_sorted = [PhysicalSortExpr {
2508            expr: col("lt1", left_schema)?,
2509            options: SortOptions {
2510                descending: false,
2511                nulls_first: true,
2512            },
2513        }]
2514        .into();
2515        let right_sorted = [PhysicalSortExpr {
2516            expr: col("rt1", right_schema)?,
2517            options: SortOptions {
2518                descending: false,
2519                nulls_first: true,
2520            },
2521        }]
2522        .into();
2523        let (left, right) = create_memory_table(
2524            left_partition,
2525            right_partition,
2526            vec![left_sorted],
2527            vec![right_sorted],
2528        )?;
2529        let intermediate_schema = Schema::new(vec![
2530            Field::new(
2531                "left",
2532                DataType::Timestamp(TimeUnit::Millisecond, None),
2533                false,
2534            ),
2535            Field::new(
2536                "right",
2537                DataType::Timestamp(TimeUnit::Millisecond, None),
2538                false,
2539            ),
2540        ]);
2541        let filter_expr = join_expr_tests_fixture_temporal(
2542            case_expr,
2543            col("left", &intermediate_schema)?,
2544            col("right", &intermediate_schema)?,
2545            &intermediate_schema,
2546        )?;
2547        let column_indices = vec![
2548            ColumnIndex {
2549                index: 3,
2550                side: JoinSide::Left,
2551            },
2552            ColumnIndex {
2553                index: 3,
2554                side: JoinSide::Right,
2555            },
2556        ];
2557        let filter =
2558            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2559        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2560        Ok(())
2561    }
2562
2563    #[rstest]
2564    #[tokio::test(flavor = "multi_thread")]
2565    async fn test_with_interval_columns(
2566        #[values(
2567            JoinType::Inner,
2568            JoinType::Left,
2569            JoinType::Right,
2570            JoinType::RightSemi,
2571            JoinType::LeftSemi,
2572            JoinType::LeftAnti,
2573            JoinType::LeftMark,
2574            JoinType::RightAnti,
2575            JoinType::RightMark,
2576            JoinType::Full
2577        )]
2578        join_type: JoinType,
2579        #[values(
2580            (4, 5),
2581            (12, 17),
2582        )]
2583        cardinality: (i32, i32),
2584    ) -> Result<()> {
2585        let session_config = SessionConfig::new().with_repartition_joins(false);
2586        let task_ctx = TaskContext::default().with_session_config(session_config);
2587        let task_ctx = Arc::new(task_ctx);
2588        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2589
2590        let left_schema = &left_partition[0].schema();
2591        let right_schema = &right_partition[0].schema();
2592        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2593        let left_sorted = [PhysicalSortExpr {
2594            expr: col("li1", left_schema)?,
2595            options: SortOptions {
2596                descending: false,
2597                nulls_first: true,
2598            },
2599        }]
2600        .into();
2601        let right_sorted = [PhysicalSortExpr {
2602            expr: col("ri1", right_schema)?,
2603            options: SortOptions {
2604                descending: false,
2605                nulls_first: true,
2606            },
2607        }]
2608        .into();
2609        let (left, right) = create_memory_table(
2610            left_partition,
2611            right_partition,
2612            vec![left_sorted],
2613            vec![right_sorted],
2614        )?;
2615        let intermediate_schema = Schema::new(vec![
2616            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2617            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2618        ]);
2619        let filter_expr = join_expr_tests_fixture_temporal(
2620            0,
2621            col("left", &intermediate_schema)?,
2622            col("right", &intermediate_schema)?,
2623            &intermediate_schema,
2624        )?;
2625        let column_indices = vec![
2626            ColumnIndex {
2627                index: 9,
2628                side: JoinSide::Left,
2629            },
2630            ColumnIndex {
2631                index: 9,
2632                side: JoinSide::Right,
2633            },
2634        ];
2635        let filter =
2636            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2637        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2638
2639        Ok(())
2640    }
2641
2642    #[rstest]
2643    #[tokio::test(flavor = "multi_thread")]
2644    async fn testing_ascending_float_pruning(
2645        #[values(
2646            JoinType::Inner,
2647            JoinType::Left,
2648            JoinType::Right,
2649            JoinType::RightSemi,
2650            JoinType::LeftSemi,
2651            JoinType::LeftAnti,
2652            JoinType::LeftMark,
2653            JoinType::RightAnti,
2654            JoinType::RightMark,
2655            JoinType::Full
2656        )]
2657        join_type: JoinType,
2658        #[values(
2659            (4, 5),
2660            (12, 17),
2661        )]
2662        cardinality: (i32, i32),
2663        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2664    ) -> Result<()> {
2665        let session_config = SessionConfig::new().with_repartition_joins(false);
2666        let task_ctx = TaskContext::default().with_session_config(session_config);
2667        let task_ctx = Arc::new(task_ctx);
2668        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2669
2670        let left_schema = &left_partition[0].schema();
2671        let right_schema = &right_partition[0].schema();
2672        let left_sorted = [PhysicalSortExpr {
2673            expr: col("l_float", left_schema)?,
2674            options: SortOptions::default(),
2675        }]
2676        .into();
2677        let right_sorted = [PhysicalSortExpr {
2678            expr: col("r_float", right_schema)?,
2679            options: SortOptions::default(),
2680        }]
2681        .into();
2682        let (left, right) = create_memory_table(
2683            left_partition,
2684            right_partition,
2685            vec![left_sorted],
2686            vec![right_sorted],
2687        )?;
2688
2689        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2690
2691        let intermediate_schema = Schema::new(vec![
2692            Field::new("left", DataType::Float64, true),
2693            Field::new("right", DataType::Float64, true),
2694        ]);
2695        let filter_expr = join_expr_tests_fixture_f64(
2696            case_expr,
2697            col("left", &intermediate_schema)?,
2698            col("right", &intermediate_schema)?,
2699        );
2700        let column_indices = vec![
2701            ColumnIndex {
2702                index: 10, // l_float
2703                side: JoinSide::Left,
2704            },
2705            ColumnIndex {
2706                index: 10, // r_float
2707                side: JoinSide::Right,
2708            },
2709        ];
2710        let filter =
2711            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2712
2713        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2714        Ok(())
2715    }
2716}