Skip to main content

datafusion_physical_plan/joins/
symmetric_hash_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file implements the symmetric hash join algorithm with range-based
19//! data pruning to join two (potentially infinite) streams.
20//!
//! A [`SymmetricHashJoinExec`] plan takes two child plans (with appropriate
22//! output ordering) and produces the join output according to the given join
23//! type and other options.
24//!
25//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations
26//! for both its children.
27
28use std::fmt::{self, Debug};
29use std::mem::{size_of, size_of_val};
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use std::vec;
33
34use crate::check_if_same_properties;
35use crate::common::SharedMemoryReservation;
36use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
37use crate::joins::stream_join_utils::{
38    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
39    calculate_filter_expr_intervals, combine_two_batches,
40    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
41    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
42};
43use crate::joins::utils::{
44    BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
45    JoinOnRef, NoopBatchTransformer, StatefulStreamResult, apply_join_filter_to_indices,
46    build_batch_from_indices, build_join_schema, check_join_is_valid, equal_rows_arr,
47    symmetric_join_output_partitioning, update_hash,
48};
49use crate::projection::{
50    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
51    physical_to_column_exprs, update_join_filter, update_join_on,
52};
53use crate::stream::EmptyRecordBatchStream;
54use crate::{
55    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
56    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
57    joins::StreamJoinPartitionMode,
58    metrics::{ExecutionPlanMetricsSet, MetricsSet},
59};
60
61use arrow::array::{
62    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
63    UInt64Array,
64};
65use arrow::compute::concat_batches;
66use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
67use arrow::record_batch::RecordBatch;
68use datafusion_common::hash_utils::create_hashes;
69use datafusion_common::utils::bisect;
70use datafusion_common::{
71    HashSet, JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err,
72    plan_err,
73};
74use datafusion_execution::TaskContext;
75use datafusion_execution::memory_pool::MemoryConsumer;
76use datafusion_expr::interval_arithmetic::Interval;
77use datafusion_physical_expr::equivalence::join_equivalence_properties;
78use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
79use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
80use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
81
82use datafusion_common::hash_utils::RandomState;
83use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
84use futures::{Stream, StreamExt, ready};
85
/// Scale factor related to shrinking the join hash maps.
// NOTE(review): the use site is outside this view; presumably this is the
// factor passed to the hash map shrink check after pruning — confirm.
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
87
88/// A symmetric hash join with range conditions is when both streams are hashed on the
89/// join key and the resulting hash tables are used to join the streams.
90/// The join is considered symmetric because the hash table is built on the join keys from both
91/// streams, and the matching of rows is based on the values of the join keys in both streams.
92/// This type of join is efficient in streaming context as it allows for fast lookups in the hash
93/// table, rather than having to scan through one or both of the streams to find matching rows, also it
94/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions),
95/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
96/// data without any memory issues.
97///
98/// For each input stream, create a hash table.
99///   - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets.
100///   - Test if input is equal to a predefined set of other inputs.
101///   - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
102///   - Try to prune other side (probe) with new [RecordBatch].
103///   - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.),
104///     output the [RecordBatch] when a pruning happens or at the end of the data.
105///
106///
107/// ``` text
108///                        +-------------------------+
109///                        |                         |
110///   left stream ---------|  Left OneSideHashJoiner |---+
111///                        |                         |   |
112///                        +-------------------------+   |
113///                                                      |
114///                                                      |--------- Joined output
115///                                                      |
116///                        +-------------------------+   |
117///                        |                         |   |
118///  right stream ---------| Right OneSideHashJoiner |---+
119///                        |                         |
120///                        +-------------------------+
121///
122/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic
123/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
124///
125///
126///               PROBE SIDE          BUILD SIDE
127///                 BUFFER              BUFFER
128///             +-------------+     +------------+
129///             |             |     |            |    Unjoinable
130///             |             |     |            |    Range
131///             |             |     |            |
132///             |             |  |---------------------------------
133///             |             |  |  |            |
134///             |             |  |  |            |
135///             |             | /   |            |
136///             |             | |   |            |
137///             |             | |   |            |
138///             |             | |   |            |
139///             |             | |   |            |
140///             |             | |   |            |    Joinable
141///             |             |/    |            |    Range
142///             |             ||    |            |
143///             |+-----------+||    |            |
144///             || Record    ||     |            |
145///             || Batch     ||     |            |
146///             |+-----------+||    |            |
147///             +-------------+\    +------------+
148///                             |
149///                             \
150///                              |---------------------------------
151///
152///  This happens when range conditions are provided on sorted columns. E.g.
153///
154///        SELECT * FROM left_table, right_table
155///        ON
156///          left_key = right_key AND
157///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
158///
159/// or
160///       SELECT * FROM left_table, right_table
161///        ON
162///          left_key = right_key AND
163///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
164///
/// In general, when new data arrives on the probe side, the range conditions can be used to
/// determine a threshold for discarding rows from the build side buffer. Take the second query
/// as an example: if both columns ("left_sorted" and "right_sorted") are sorted in ascending
/// order (other scenarios may use different sort orders), the join condition is
/// "left_sorted > right_sorted - 3", and the latest value on the right input is 1234, then every
/// future right row satisfies "right_sorted >= 1234". The left side buffer therefore only needs
/// to keep rows where "left_sorted > 1234 - 3 = 1231"; any rows at or below 1231 (since
/// ascending) can be dropped from the buffer.
172/// ```
#[derive(Debug, Clone)]
pub struct SymmetricHashJoinExec {
    /// Left side input plan (stream)
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right side input plan (stream)
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Pairs of (left, right) expressions used as the equi-join keys
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Optional filter applied when finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// `RandomState` for the hashing algorithm; it is cloned into every
    /// partition stream created by `execute`
    random_state: RandomState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Defines the null equality for the join.
    pub(crate) null_equality: NullEquality,
    /// Left side sort expression(s); reported as the required ordering of the
    /// left input and used, together with `filter`, to prepare the sorted
    /// filter expressions at execution time
    pub(crate) left_sort_exprs: Option<LexOrdering>,
    /// Right side sort expression(s); the right side counterpart of
    /// `left_sort_exprs`
    pub(crate) right_sort_exprs: Option<LexOrdering>,
    /// Partition mode; determines the required input distribution
    mode: StreamJoinPartitionMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
202
203impl SymmetricHashJoinExec {
204    /// Tries to create a new [SymmetricHashJoinExec].
205    /// # Error
206    /// This function errors when:
207    /// - It is not possible to join the left and right sides on keys `on`, or
208    /// - It fails to construct `SortedFilterExpr`s, or
209    /// - It fails to create the [ExprIntervalGraph].
210    #[expect(clippy::too_many_arguments)]
211    pub fn try_new(
212        left: Arc<dyn ExecutionPlan>,
213        right: Arc<dyn ExecutionPlan>,
214        on: JoinOn,
215        filter: Option<JoinFilter>,
216        join_type: &JoinType,
217        null_equality: NullEquality,
218        left_sort_exprs: Option<LexOrdering>,
219        right_sort_exprs: Option<LexOrdering>,
220        mode: StreamJoinPartitionMode,
221    ) -> Result<Self> {
222        let left_schema = left.schema();
223        let right_schema = right.schema();
224
225        // Error out if no "on" constraints are given:
226        if on.is_empty() {
227            return plan_err!(
228                "On constraints in SymmetricHashJoinExec should be non-empty"
229            );
230        }
231
232        // Check if the join is valid with the given on constraints:
233        check_join_is_valid(&left_schema, &right_schema, &on)?;
234
235        // Build the join schema from the left and right schemas:
236        let (schema, column_indices) =
237            build_join_schema(&left_schema, &right_schema, join_type);
238
239        // Initialize the random state for the join operation:
240        let random_state = RandomState::with_seed(0);
241        let schema = Arc::new(schema);
242        let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?;
243        Ok(SymmetricHashJoinExec {
244            left,
245            right,
246            on,
247            filter,
248            join_type: *join_type,
249            random_state,
250            metrics: ExecutionPlanMetricsSet::new(),
251            column_indices,
252            null_equality,
253            left_sort_exprs,
254            right_sort_exprs,
255            mode,
256            cache: Arc::new(cache),
257        })
258    }
259
260    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
261    fn compute_properties(
262        left: &Arc<dyn ExecutionPlan>,
263        right: &Arc<dyn ExecutionPlan>,
264        schema: SchemaRef,
265        join_type: JoinType,
266        join_on: JoinOnRef,
267    ) -> Result<PlanProperties> {
268        // Calculate equivalence properties:
269        let eq_properties = join_equivalence_properties(
270            left.equivalence_properties().clone(),
271            right.equivalence_properties().clone(),
272            &join_type,
273            schema,
274            &[false, false],
275            // Has alternating probe side
276            None,
277            join_on,
278        )?;
279
280        let output_partitioning =
281            symmetric_join_output_partitioning(left, right, &join_type)?;
282
283        Ok(PlanProperties::new(
284            eq_properties,
285            output_partitioning,
286            emission_type_from_children([left, right]),
287            boundedness_from_children([left, right]),
288        ))
289    }
290
    /// Returns the left input plan (stream) of the join.
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }
295
    /// Returns the right input plan (stream) of the join.
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }
300
    /// Returns the (left, right) expression pairs used as the join keys.
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }
305
    /// Returns the join filter, if any, applied when finding matching rows.
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }
310
    /// Returns how the join is performed (the join type).
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }
315
    /// Returns the null equality setting of the join.
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }
320
    /// Returns the partition mode of the join.
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
        self.mode
    }
325
    /// Returns the sort expression(s) of the left side, if any.
    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
        self.left_sort_exprs.as_ref()
    }
330
    /// Returns the sort expression(s) of the right side, if any.
    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
        self.right_sort_exprs.as_ref()
    }
335
336    /// Check if order information covers every column in the filter expression.
337    pub fn check_if_order_information_available(&self) -> Result<bool> {
338        if let Some(filter) = self.filter() {
339            let left = self.left();
340            if let Some(left_ordering) = left.output_ordering() {
341                let right = self.right();
342                if let Some(right_ordering) = right.output_ordering() {
343                    let left_convertible = convert_sort_expr_with_filter_schema(
344                        &JoinSide::Left,
345                        filter,
346                        &left.schema(),
347                        &left_ordering[0],
348                    )?
349                    .is_some();
350                    let right_convertible = convert_sort_expr_with_filter_schema(
351                        &JoinSide::Right,
352                        filter,
353                        &right.schema(),
354                        &right_ordering[0],
355                    )?
356                    .is_some();
357                    return Ok(left_convertible && right_convertible);
358                }
359            }
360        }
361        Ok(false)
362    }
363
364    fn with_new_children_and_same_properties(
365        &self,
366        mut children: Vec<Arc<dyn ExecutionPlan>>,
367    ) -> Self {
368        let left = children.swap_remove(0);
369        let right = children.swap_remove(0);
370        Self {
371            left,
372            right,
373            metrics: ExecutionPlanMetricsSet::new(),
374            ..Self::clone(self)
375        }
376    }
377}
378
379impl DisplayAs for SymmetricHashJoinExec {
380    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
381        match t {
382            DisplayFormatType::Default | DisplayFormatType::Verbose => {
383                let display_filter = self.filter.as_ref().map_or_else(
384                    || "".to_string(),
385                    |f| format!(", filter={}", f.expression()),
386                );
387                let on = self
388                    .on
389                    .iter()
390                    .map(|(c1, c2)| format!("({c1}, {c2})"))
391                    .collect::<Vec<String>>()
392                    .join(", ");
393                write!(
394                    f,
395                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
396                    self.mode, self.join_type, on, display_filter
397                )
398            }
399            DisplayFormatType::TreeRender => {
400                let on = self
401                    .on
402                    .iter()
403                    .map(|(c1, c2)| {
404                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
405                    })
406                    .collect::<Vec<String>>()
407                    .join(", ");
408
409                writeln!(f, "mode={:?}", self.mode)?;
410                if *self.join_type() != JoinType::Inner {
411                    writeln!(f, "join_type={:?}", self.join_type)?;
412                }
413                writeln!(f, "on={on}")
414            }
415        }
416    }
417}
418
419impl ExecutionPlan for SymmetricHashJoinExec {
    /// Returns the static name of this operator.
    fn name(&self) -> &'static str {
        "SymmetricHashJoinExec"
    }
423
    /// Returns the cached plan properties computed when the plan was created.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
427
428    fn required_input_distribution(&self) -> Vec<Distribution> {
429        match self.mode {
430            StreamJoinPartitionMode::Partitioned => {
431                let (left_expr, right_expr) = self
432                    .on
433                    .iter()
434                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
435                    .unzip();
436                vec![
437                    Distribution::HashPartitioned(left_expr),
438                    Distribution::HashPartitioned(right_expr),
439                ]
440            }
441            StreamJoinPartitionMode::SinglePartition => {
442                vec![Distribution::SinglePartition, Distribution::SinglePartition]
443            }
444        }
445    }
446
447    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
448        vec![
449            self.left_sort_exprs
450                .as_ref()
451                .map(|e| OrderingRequirements::from(e.clone())),
452            self.right_sort_exprs
453                .as_ref()
454                .map(|e| OrderingRequirements::from(e.clone())),
455        ]
456    }
457
    /// Returns the children of this plan: the left input followed by the
    /// right input.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }
461
462    fn with_new_children(
463        self: Arc<Self>,
464        children: Vec<Arc<dyn ExecutionPlan>>,
465    ) -> Result<Arc<dyn ExecutionPlan>> {
466        check_if_same_properties!(self, children);
467        Ok(Arc::new(SymmetricHashJoinExec::try_new(
468            Arc::clone(&children[0]),
469            Arc::clone(&children[1]),
470            self.on.clone(),
471            self.filter.clone(),
472            &self.join_type,
473            self.null_equality,
474            self.left_sort_exprs.clone(),
475            self.right_sort_exprs.clone(),
476            self.mode,
477        )?))
478    }
479
    /// Returns a copy of the execution metrics of this plan, obtained via
    /// `clone_inner`.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
483
484    fn execute(
485        &self,
486        partition: usize,
487        context: Arc<TaskContext>,
488    ) -> Result<SendableRecordBatchStream> {
489        let left_partitions = self.left.output_partitioning().partition_count();
490        let right_partitions = self.right.output_partitioning().partition_count();
491        assert_eq_or_internal_err!(
492            left_partitions,
493            right_partitions,
494            "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
495                 consider using RepartitionExec"
496        );
497        // If `filter_state` and `filter` are both present, then calculate sorted
498        // filter expressions for both sides, and build an expression graph.
499        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
500            self.left_sort_exprs(),
501            self.right_sort_exprs(),
502            &self.filter,
503        ) {
504            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
505                let (left, right, graph) = prepare_sorted_exprs(
506                    filter,
507                    &self.left,
508                    &self.right,
509                    left_sort_exprs,
510                    right_sort_exprs,
511                )?;
512                (Some(left), Some(right), Some(graph))
513            }
514            // If `filter_state` or `filter` is not present, then return None
515            // for all three values:
516            _ => (None, None, None),
517        };
518
519        let (on_left, on_right) = self.on.iter().cloned().unzip();
520
521        let left_side_joiner =
522            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
523        let right_side_joiner =
524            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
525
526        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
527
528        let right_stream = self.right.execute(partition, Arc::clone(&context))?;
529
530        let batch_size = context.session_config().batch_size();
531        let enforce_batch_size_in_joins =
532            context.session_config().enforce_batch_size_in_joins();
533
534        let reservation = Arc::new(
535            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
536                .register(context.memory_pool()),
537        );
538        if let Some(g) = graph.as_ref() {
539            reservation.try_grow(g.size())?;
540        }
541
542        if enforce_batch_size_in_joins {
543            Ok(Box::pin(SymmetricHashJoinStream {
544                left_stream,
545                right_stream,
546                schema: self.schema(),
547                filter: self.filter.clone(),
548                join_type: self.join_type,
549                random_state: self.random_state.clone(),
550                left: left_side_joiner,
551                right: right_side_joiner,
552                column_indices: self.column_indices.clone(),
553                metrics: StreamJoinMetrics::new(partition, &self.metrics),
554                graph,
555                left_sorted_filter_expr,
556                right_sorted_filter_expr,
557                null_equality: self.null_equality,
558                state: SHJStreamState::PullRight,
559                reservation,
560                batch_transformer: BatchSplitter::new(batch_size),
561            }))
562        } else {
563            Ok(Box::pin(SymmetricHashJoinStream {
564                left_stream,
565                right_stream,
566                schema: self.schema(),
567                filter: self.filter.clone(),
568                join_type: self.join_type,
569                random_state: self.random_state.clone(),
570                left: left_side_joiner,
571                right: right_side_joiner,
572                column_indices: self.column_indices.clone(),
573                metrics: StreamJoinMetrics::new(partition, &self.metrics),
574                graph,
575                left_sorted_filter_expr,
576                right_sorted_filter_expr,
577                null_equality: self.null_equality,
578                state: SHJStreamState::PullRight,
579                reservation,
580                batch_transformer: NoopBatchTransformer::new(),
581            }))
582        }
583    }
584
585    /// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done,
586    /// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan.
587    /// Otherwise, it returns None.
588    fn try_swapping_with_projection(
589        &self,
590        projection: &ProjectionExec,
591    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
592        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
593        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
594        else {
595            return Ok(None);
596        };
597
598        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
599            self.left().schema().fields().len(),
600            &projection_as_columns,
601        );
602
603        if !join_allows_pushdown(
604            &projection_as_columns,
605            &self.schema(),
606            far_right_left_col_ind,
607            far_left_right_col_ind,
608        ) {
609            return Ok(None);
610        }
611
612        let Some(new_on) = update_join_on(
613            &projection_as_columns[0..=far_right_left_col_ind as _],
614            &projection_as_columns[far_left_right_col_ind as _..],
615            self.on(),
616            self.left().schema().fields().len(),
617        ) else {
618            return Ok(None);
619        };
620
621        let new_filter = if let Some(filter) = self.filter() {
622            match update_join_filter(
623                &projection_as_columns[0..=far_right_left_col_ind as _],
624                &projection_as_columns[far_left_right_col_ind as _..],
625                filter,
626                self.left().schema().fields().len(),
627            ) {
628                Some(updated_filter) => Some(updated_filter),
629                None => return Ok(None),
630            }
631        } else {
632            None
633        };
634
635        let (new_left, new_right) = new_join_children(
636            &projection_as_columns,
637            far_right_left_col_ind,
638            far_left_right_col_ind,
639            self.left(),
640            self.right(),
641        )?;
642
643        SymmetricHashJoinExec::try_new(
644            Arc::new(new_left),
645            Arc::new(new_right),
646            new_on,
647            new_filter,
648            self.join_type(),
649            self.null_equality(),
650            self.right().output_ordering().cloned(),
651            self.left().output_ordering().cloned(),
652            self.partition_mode(),
653        )
654        .map(|e| Some(Arc::new(e) as _))
655    }
656}
657
/// A stream that produces the output [RecordBatch]es of a symmetric hash join
/// for one partition, reading from both the left and the right input streams.
struct SymmetricHashJoinStream<T> {
    /// Input streams
    left_stream: SendableRecordBatchStream,
    right_stream: SendableRecordBatchStream,
    /// Output schema of the join
    schema: Arc<Schema>,
    /// join filter
    filter: Option<JoinFilter>,
    /// type of the join
    join_type: JoinType,
    /// left hash joiner
    left: OneSideHashJoiner,
    /// right hash joiner
    right: OneSideHashJoiner,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Expression graph for range pruning.
    graph: Option<ExprIntervalGraph>,
    /// Left globally sorted filter expr
    left_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Right globally sorted filter expr
    right_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Random state used for hashing initialization
    random_state: RandomState,
    /// Defines the null equality for the join.
    null_equality: NullEquality,
    /// Metrics
    metrics: StreamJoinMetrics,
    /// Memory reservation
    reservation: SharedMemoryReservation,
    /// State machine for input execution
    state: SHJStreamState,
    /// Transforms the output batch before returning.
    batch_transformer: T,
}
694
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
    for SymmetricHashJoinStream<T>
{
    /// Returns the schema of the batches produced by this stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
702
impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
    /// Every item is either an output batch or an error.
    type Item = Result<RecordBatch>;

    /// Polls for the next output batch by delegating to `poll_next_impl`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}
713
714/// Determine the pruning length for `buffer`.
715///
716/// This function evaluates the build side filter expression, converts the
717/// result into an array and determines the pruning length by performing a
718/// binary search on the array.
719///
720/// # Arguments
721///
722/// * `buffer`: The record batch to be pruned.
723/// * `build_side_filter_expr`: The filter expression on the build side used
724///   to determine the pruning length.
725///
726/// # Returns
727///
728/// A [Result] object that contains the pruning length. The function will return
729/// an error if
730/// - there is an issue evaluating the build side filter expression;
731/// - there is an issue converting the build side filter expression into an array
732fn determine_prune_length(
733    buffer: &RecordBatch,
734    build_side_filter_expr: &SortedFilterExpr,
735) -> Result<usize> {
736    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
737    let interval = build_side_filter_expr.interval();
738    // Evaluate the build side filter expression and convert it into an array
739    let batch_arr = origin_sorted_expr
740        .expr
741        .evaluate(buffer)?
742        .into_array(buffer.num_rows())?;
743
744    // Get the lower or upper interval based on the sort direction
745    let target = if origin_sorted_expr.options.descending {
746        interval.upper().clone()
747    } else {
748        interval.lower().clone()
749    };
750
751    // Perform binary search on the array to determine the length of the record batch to be pruned
752    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
753}
754
755/// This method determines if the result of the join should be produced in the final step or not.
756///
757/// # Arguments
758///
759/// * `build_side` - Enum indicating the side of the join used as the build side.
760/// * `join_type` - Enum indicating the type of join to be performed.
761///
762/// # Returns
763///
764/// A boolean indicating whether the result of the join should be produced in the final step or not.
765/// The result will be true if the build side is JoinSide::Left and the join type is one of
766/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi.
767/// If the build side is JoinSide::Right, the result will be true if the join type
768/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi.
769fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
770    if build_side == JoinSide::Left {
771        matches!(
772            join_type,
773            JoinType::Left
774                | JoinType::LeftAnti
775                | JoinType::Full
776                | JoinType::LeftSemi
777                | JoinType::LeftMark
778        )
779    } else {
780        matches!(
781            join_type,
782            JoinType::Right
783                | JoinType::RightAnti
784                | JoinType::Full
785                | JoinType::RightSemi
786                | JoinType::RightMark
787        )
788    }
789}
790
791/// Calculate indices by join type.
792///
793/// This method returns a tuple of two arrays: build and probe indices.
794/// The length of both arrays will be the same.
795///
796/// # Arguments
797///
798/// * `build_side`: Join side which defines the build side.
799/// * `prune_length`: Length of the prune data.
800/// * `visited_rows`: Hash set of visited rows of the build side.
801/// * `deleted_offset`: Deleted offset of the build side.
802/// * `join_type`: The type of join to be performed.
803///
804/// # Returns
805///
806/// A tuple of two arrays of primitive types representing the build and probe indices.
807fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
808    build_side: JoinSide,
809    prune_length: usize,
810    visited_rows: &HashSet<usize>,
811    deleted_offset: usize,
812    join_type: JoinType,
813) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
814where
815    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
816{
817    // Store the result in a tuple
818    let result = match (build_side, join_type) {
819        // For a mark join we “mark” each build‐side row with a dummy 0 in the probe‐side index
820        // if it ever matched. For example, if
821        //
822        // prune_length = 5
823        // deleted_offset = 0
824        // visited_rows = {1, 3}
825        //
826        // then we produce:
827        //
828        // build_indices = [0, 1, 2, 3, 4]
829        // probe_indices = [None, Some(0), None, Some(0), None]
830        //
831        // Example: for each build row i in [0..5):
832        //   – We always output its own index i in `build_indices`
833        //   – We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None`
834        (JoinSide::Left, JoinType::LeftMark) => {
835            let build_indices = (0..prune_length)
836                .map(L::Native::from_usize)
837                .collect::<PrimitiveArray<L>>();
838            let probe_indices = (0..prune_length)
839                .map(|idx| {
840                    // For mark join we output a dummy index 0 to indicate the row had a match
841                    visited_rows
842                        .contains(&(idx + deleted_offset))
843                        .then_some(R::Native::from_usize(0).unwrap())
844                })
845                .collect();
846            (build_indices, probe_indices)
847        }
848        (JoinSide::Right, JoinType::RightMark) => {
849            let build_indices = (0..prune_length)
850                .map(L::Native::from_usize)
851                .collect::<PrimitiveArray<L>>();
852            let probe_indices = (0..prune_length)
853                .map(|idx| {
854                    // For mark join we output a dummy index 0 to indicate the row had a match
855                    visited_rows
856                        .contains(&(idx + deleted_offset))
857                        .then_some(R::Native::from_usize(0).unwrap())
858                })
859                .collect();
860            (build_indices, probe_indices)
861        }
862        // In the case of `Left` or `Right` join, or `Full` join, get the anti indices
863        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
864        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
865        | (_, JoinType::Full) => {
866            let build_unmatched_indices =
867                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
868            let mut builder =
869                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
870            builder.append_nulls(build_unmatched_indices.len());
871            let probe_indices = builder.finish();
872            (build_unmatched_indices, probe_indices)
873        }
874        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
875        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
876            let build_unmatched_indices =
877                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
878            let mut builder =
879                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
880            builder.append_nulls(build_unmatched_indices.len());
881            let probe_indices = builder.finish();
882            (build_unmatched_indices, probe_indices)
883        }
884        // The case of other join types is not considered
885        _ => unreachable!(),
886    };
887    Ok(result)
888}
889
890/// This function produces unmatched record results based on the build side,
891/// join type and other parameters.
892///
893/// The method uses first `prune_length` rows from the build side input buffer
894/// to produce results.
895///
896/// # Arguments
897///
898/// * `output_schema` - The schema of the final output record batch.
899/// * `prune_length` - The length of the determined prune length.
900/// * `probe_schema` - The schema of the probe [RecordBatch].
901/// * `join_type` - The type of join to be performed.
902/// * `column_indices` - Indices of columns that are being joined.
903///
904/// # Returns
905///
906/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
907pub(crate) fn build_side_determined_results(
908    build_hash_joiner: &OneSideHashJoiner,
909    output_schema: &SchemaRef,
910    prune_length: usize,
911    probe_schema: SchemaRef,
912    join_type: JoinType,
913    column_indices: &[ColumnIndex],
914) -> Result<Option<RecordBatch>> {
915    // Check if we need to produce a result in the final output:
916    if prune_length > 0
917        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
918    {
919        // Calculate the indices for build and probe sides based on join type and build side:
920        let (build_indices, probe_indices) = calculate_indices_by_join_type(
921            build_hash_joiner.build_side,
922            prune_length,
923            &build_hash_joiner.visited_rows,
924            build_hash_joiner.deleted_offset,
925            join_type,
926        )?;
927
928        // Create an empty probe record batch:
929        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
930        // Build the final result from the indices of build and probe sides:
931        build_batch_from_indices(
932            output_schema.as_ref(),
933            &build_hash_joiner.input_buffer,
934            &empty_probe_batch,
935            &build_indices,
936            &probe_indices,
937            column_indices,
938            build_hash_joiner.build_side,
939            join_type,
940        )
941        .map(|batch| (batch.num_rows() > 0).then_some(batch))
942    } else {
943        // If we don't need to produce a result, return None
944        Ok(None)
945    }
946}
947
948/// This method performs a join between the build side input buffer and the probe side batch.
949///
950/// # Arguments
951///
952/// * `build_hash_joiner` - Build side hash joiner
953/// * `probe_hash_joiner` - Probe side hash joiner
954/// * `schema` - A reference to the schema of the output record batch.
955/// * `join_type` - The type of join to be performed.
956/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
957/// * `filter` - An optional filter on the join condition.
958/// * `probe_batch` - The second record batch to be joined.
959/// * `column_indices` - An array of columns to be selected for the result of the join.
960/// * `random_state` - The random state for the join.
961/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
962///
963/// # Returns
964///
965/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
966/// If the join type is one of the above four, the function will return [None].
967#[expect(clippy::too_many_arguments)]
968pub(crate) fn join_with_probe_batch(
969    build_hash_joiner: &mut OneSideHashJoiner,
970    probe_hash_joiner: &mut OneSideHashJoiner,
971    schema: &SchemaRef,
972    join_type: JoinType,
973    filter: Option<&JoinFilter>,
974    probe_batch: &RecordBatch,
975    column_indices: &[ColumnIndex],
976    random_state: &RandomState,
977    null_equality: NullEquality,
978) -> Result<Option<RecordBatch>> {
979    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
980        return Ok(None);
981    }
982    let (build_indices, probe_indices) = lookup_join_hashmap(
983        &build_hash_joiner.hashmap,
984        &build_hash_joiner.input_buffer,
985        probe_batch,
986        &build_hash_joiner.on,
987        &probe_hash_joiner.on,
988        random_state,
989        null_equality,
990        &mut build_hash_joiner.hashes_buffer,
991        Some(build_hash_joiner.deleted_offset),
992    )?;
993
994    let (build_indices, probe_indices) = if let Some(filter) = filter {
995        apply_join_filter_to_indices(
996            &build_hash_joiner.input_buffer,
997            probe_batch,
998            build_indices,
999            probe_indices,
1000            filter,
1001            build_hash_joiner.build_side,
1002            None,
1003            join_type,
1004        )?
1005    } else {
1006        (build_indices, probe_indices)
1007    };
1008
1009    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
1010        record_visited_indices(
1011            &mut build_hash_joiner.visited_rows,
1012            build_hash_joiner.deleted_offset,
1013            &build_indices,
1014        );
1015    }
1016    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
1017        record_visited_indices(
1018            &mut probe_hash_joiner.visited_rows,
1019            probe_hash_joiner.offset,
1020            &probe_indices,
1021        );
1022    }
1023    if matches!(
1024        join_type,
1025        JoinType::LeftAnti
1026            | JoinType::RightAnti
1027            | JoinType::LeftSemi
1028            | JoinType::LeftMark
1029            | JoinType::RightSemi
1030            | JoinType::RightMark
1031    ) {
1032        Ok(None)
1033    } else {
1034        build_batch_from_indices(
1035            schema,
1036            &build_hash_joiner.input_buffer,
1037            probe_batch,
1038            &build_indices,
1039            &probe_indices,
1040            column_indices,
1041            build_hash_joiner.build_side,
1042            join_type,
1043        )
1044        .map(|batch| (batch.num_rows() > 0).then_some(batch))
1045    }
1046}
1047
1048/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential
1049/// hash collisions.
1050///
1051/// # Arguments
1052///
1053/// * `build_hashmap` - hashmap collected from build side data.
1054/// * `build_batch` - Build side record batch.
1055/// * `probe_batch` - Probe side record batch.
1056/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
1057/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
1058/// * `random_state` - The random state for the join.
1059/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
1060/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
1061/// * `deleted_offset` - deleted offset for build side data.
1062///
1063/// # Returns
1064///
1065/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side,
1066/// matched by join key columns.
1067#[expect(clippy::too_many_arguments)]
1068fn lookup_join_hashmap(
1069    build_hashmap: &PruningJoinHashMap,
1070    build_batch: &RecordBatch,
1071    probe_batch: &RecordBatch,
1072    build_on: &[PhysicalExprRef],
1073    probe_on: &[PhysicalExprRef],
1074    random_state: &RandomState,
1075    null_equality: NullEquality,
1076    hashes_buffer: &mut Vec<u64>,
1077    deleted_offset: Option<usize>,
1078) -> Result<(UInt64Array, UInt32Array)> {
1079    let keys_values = evaluate_expressions_to_arrays(probe_on, probe_batch)?;
1080    let build_join_values = evaluate_expressions_to_arrays(build_on, build_batch)?;
1081
1082    hashes_buffer.clear();
1083    hashes_buffer.resize(probe_batch.num_rows(), 0);
1084    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1085
1086    // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm
1087    // will return build indices for each probe row in a reverse order as such:
1088    // Build Indices: [5, 4, 3]
1089    // Probe Indices: [1, 1, 1]
1090    //
1091    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
1092    // Let's consider probe rows [0,1] as an example:
1093    //
1094    // When the probe iteration sequence is reversed, the following pairings can be derived:
1095    //
1096    // For probe row 1:
1097    //     (5, 1)
1098    //     (4, 1)
1099    //     (3, 1)
1100    //
1101    // For probe row 0:
1102    //     (5, 0)
1103    //     (4, 0)
1104    //     (3, 0)
1105    //
1106    // After reversing both sets of indices, we obtain reversed indices:
1107    //
1108    //     (3,0)
1109    //     (4,0)
1110    //     (5,0)
1111    //     (3,1)
1112    //     (4,1)
1113    //     (5,1)
1114    //
1115    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
1116    let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices(
1117        Box::new(hash_values.iter().enumerate().rev()),
1118        deleted_offset,
1119    );
1120
1121    matched_probe.reverse();
1122    matched_build.reverse();
1123
1124    let build_indices: UInt64Array = matched_build.into();
1125    let probe_indices: UInt32Array = matched_probe.into();
1126
1127    let (build_indices, probe_indices) = equal_rows_arr(
1128        &build_indices,
1129        &probe_indices,
1130        &build_join_values,
1131        &keys_values,
1132        null_equality,
1133    )?;
1134
1135    Ok((build_indices, probe_indices))
1136}
1137
/// Per-side state of the symmetric hash join: the buffered rows of one input,
/// the hash map built over them, and the bookkeeping needed to prune the
/// buffer and to emit build-side-determined results.
pub struct OneSideHashJoiner {
    /// The join side (left or right) this joiner represents
    build_side: JoinSide,
    /// Buffered input rows of this side: incoming batches are appended in
    /// `update_internal_state`, leading rows are dropped in
    /// `prune_internal_state`
    pub input_buffer: RecordBatch,
    /// Join key expressions for this side
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hash map over the join keys of the buffered rows
    pub(crate) hashmap: PruningJoinHashMap,
    /// Scratch buffer for hash values, reused across batches
    pub(crate) hashes_buffer: Vec<u64>,
    /// Indices of rows of this side that have matched so far; entries in the
    /// range `deleted_offset..deleted_offset + prune_length` are removed when
    /// rows are pruned
    pub(crate) visited_rows: HashSet<usize>,
    /// Row offset handed to `update_hash` for incoming batches. It is not
    /// modified by the methods of this type — presumably advanced by the
    /// caller as batches are consumed (TODO confirm)
    pub(crate) offset: usize,
    /// Total number of rows pruned from the front of `input_buffer` so far
    pub(crate) deleted_offset: usize,
}
1156
1157impl OneSideHashJoiner {
1158    pub fn size(&self) -> usize {
1159        let mut size = 0;
1160        size += size_of_val(self);
1161        size += size_of_val(&self.build_side);
1162        size += self.input_buffer.get_array_memory_size();
1163        size += size_of_val(&self.on);
1164        size += self.hashmap.size();
1165        size += self.hashes_buffer.capacity() * size_of::<u64>();
1166        size += self.visited_rows.capacity() * size_of::<usize>();
1167        size += size_of_val(&self.offset);
1168        size += size_of_val(&self.deleted_offset);
1169        size
1170    }
1171    pub fn new(
1172        build_side: JoinSide,
1173        on: Vec<PhysicalExprRef>,
1174        schema: SchemaRef,
1175    ) -> Self {
1176        Self {
1177            build_side,
1178            input_buffer: RecordBatch::new_empty(schema),
1179            on,
1180            hashmap: PruningJoinHashMap::with_capacity(0),
1181            hashes_buffer: vec![],
1182            visited_rows: HashSet::new(),
1183            offset: 0,
1184            deleted_offset: 0,
1185        }
1186    }
1187
1188    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
1189    ///
1190    /// # Arguments
1191    ///
1192    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
1193    /// * `random_state` - The random state used to hash values
1194    ///
1195    /// # Returns
1196    ///
1197    /// Returns a [Result] encapsulating any intermediate errors.
1198    pub(crate) fn update_internal_state(
1199        &mut self,
1200        batch: &RecordBatch,
1201        random_state: &RandomState,
1202    ) -> Result<()> {
1203        // Merge the incoming batch with the existing input buffer:
1204        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
1205        // Resize the hashes buffer to the number of rows in the incoming batch:
1206        self.hashes_buffer.resize(batch.num_rows(), 0);
1207        // Get allocation_info before adding the item
1208        // Update the hashmap with the join key values and hashes of the incoming batch:
1209        update_hash(
1210            &self.on,
1211            batch,
1212            &mut self.hashmap,
1213            self.offset,
1214            random_state,
1215            &mut self.hashes_buffer,
1216            self.deleted_offset,
1217            false,
1218        )?;
1219        Ok(())
1220    }
1221
1222    /// Calculate prune length.
1223    ///
1224    /// # Arguments
1225    ///
1226    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression..
1227    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
1228    /// * `graph` - A mutable reference to the physical expression graph.
1229    ///
1230    /// # Returns
1231    ///
1232    /// A Result object that contains the pruning length.
1233    pub(crate) fn calculate_prune_length_with_probe_batch(
1234        &mut self,
1235        build_side_sorted_filter_expr: &mut SortedFilterExpr,
1236        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
1237        graph: &mut ExprIntervalGraph,
1238    ) -> Result<usize> {
1239        // Return early if the input buffer is empty:
1240        if self.input_buffer.num_rows() == 0 {
1241            return Ok(0);
1242        }
1243        // Process the build and probe side sorted filter expressions if both are present:
1244        // Collect the sorted filter expressions into a vector of (node_index, interval) tuples:
1245        let mut filter_intervals = vec![];
1246        for expr in [
1247            &build_side_sorted_filter_expr,
1248            &probe_side_sorted_filter_expr,
1249        ] {
1250            filter_intervals.push((expr.node_index(), expr.interval().clone()))
1251        }
1252        // Update the physical expression graph using the join filter intervals:
1253        graph.update_ranges(&mut filter_intervals, Interval::TRUE)?;
1254        // Extract the new join filter interval for the build side:
1255        let calculated_build_side_interval = filter_intervals.remove(0).1;
1256        // If the intervals have not changed, return early without pruning:
1257        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
1258            return Ok(0);
1259        }
1260        // Update the build side interval and determine the pruning length:
1261        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);
1262
1263        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
1264    }
1265
1266    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
1267        // Prune the hash values:
1268        self.hashmap.prune_hash_values(
1269            prune_length,
1270            self.deleted_offset as u64,
1271            HASHMAP_SHRINK_SCALE_FACTOR,
1272        );
1273        // Remove pruned rows from the visited rows set:
1274        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
1275            self.visited_rows.remove(&row);
1276        }
1277        // Update the input buffer after pruning:
1278        self.input_buffer = self
1279            .input_buffer
1280            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
1281        // Increment the deleted offset:
1282        self.deleted_offset += prune_length;
1283        Ok(())
1284    }
1285}
1286
1287/// `SymmetricHashJoinStream` manages incremental join operations between two
1288/// streams. Unlike traditional join approaches that need to scan one side of
1289/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
1290/// more dynamic join operations by working with streams as they emit data. This
1291/// approach allows for more efficient processing, particularly in scenarios
1292/// where waiting for complete data materialization is not feasible or optimal.
1293/// The trait provides a framework for handling various states of such a join
1294/// process, ensuring that join logic is efficiently executed as data becomes
1295/// available from either stream.
1296///
1297/// This implementation performs eager joins of data from two different asynchronous
1298/// streams, typically referred to as left and right streams. The implementation
1299/// provides a comprehensive set of methods to control and execute the join
1300/// process, leveraging the states defined in `SHJStreamState`. Methods are
1301/// primarily focused on asynchronously fetching data batches from each stream,
1302/// processing them, and managing transitions between various states of the join.
1303///
/// This implementation uses a state machine approach to navigate the different
/// stages of the join operation, handling data from both streams and determining
/// when the join completes.
1307///
1308/// State Transitions:
1309/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
1310///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
1311///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
1312///       processing the batch.
1313///     - On error (`Some(Err(e))`), the error is returned, and the state remains
1314///       unchanged.
1315///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
1316///       to proceed with the join process.
1317/// - From `PullRight` to `PullLeft` or `RightExhausted`:
1318///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
1319///     - If a batch is available, state changes to `PullLeft` for processing.
1320///     - On error, the error is returned without changing the state.
1321///     - If right stream is exhausted (`None`), state transitions to `RightExhausted`,
1322///       with a `Continue` result.
1323/// - Handling `RightExhausted` and `LeftExhausted`:
1324///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
1325///     when streams are exhausted:
1326///     - They attempt to continue processing with the other stream.
1327///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
1328/// - Transition to `BothExhausted { final_result: true }`:
1329///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
1330///     exhausted, indicating completion of processing and availability of final results.
1331impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
1332    /// Implements the main polling logic for the join stream.
1333    ///
1334    /// This method continuously checks the state of the join stream and
1335    /// acts accordingly by delegating the handling to appropriate sub-methods
1336    /// depending on the current state.
1337    ///
1338    /// # Arguments
1339    ///
1340    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
1341    ///
1342    /// # Returns
1343    ///
1344    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
1345    fn poll_next_impl(
1346        &mut self,
1347        cx: &mut Context<'_>,
1348    ) -> Poll<Option<Result<RecordBatch>>> {
1349        loop {
1350            match self.batch_transformer.next() {
1351                None => {
1352                    let result = match self.state() {
1353                        SHJStreamState::PullRight => {
1354                            ready!(self.fetch_next_from_right_stream(cx))
1355                        }
1356                        SHJStreamState::PullLeft => {
1357                            ready!(self.fetch_next_from_left_stream(cx))
1358                        }
1359                        SHJStreamState::RightExhausted => {
1360                            ready!(self.handle_right_stream_end(cx))
1361                        }
1362                        SHJStreamState::LeftExhausted => {
1363                            ready!(self.handle_left_stream_end(cx))
1364                        }
1365                        SHJStreamState::BothExhausted {
1366                            final_result: false,
1367                        } => self.prepare_for_final_results_after_exhaustion(),
1368                        SHJStreamState::BothExhausted { final_result: true } => {
1369                            return Poll::Ready(None);
1370                        }
1371                    };
1372
1373                    match result? {
1374                        StatefulStreamResult::Ready(None) => {
1375                            return Poll::Ready(None);
1376                        }
1377                        StatefulStreamResult::Ready(Some(batch)) => {
1378                            self.batch_transformer.set_batch(batch);
1379                        }
1380                        _ => {}
1381                    }
1382                }
1383                Some((batch, _)) => {
1384                    return self
1385                        .metrics
1386                        .baseline_metrics
1387                        .record_poll(Poll::Ready(Some(Ok(batch))));
1388                }
1389            }
1390        }
1391    }
1392
1393    /// Release the right input pipeline's resources.
1394    fn cleanup_depleted_right_stream(&mut self) {
1395        let right_schema = self.right_stream.schema();
1396        self.right_stream = Box::pin(EmptyRecordBatchStream::new(right_schema));
1397    }
1398
1399    /// Release the left input pipeline's resources.
1400    fn cleanup_depleted_left_stream(&mut self) {
1401        let left_schema = self.left_stream.schema();
1402        self.left_stream = Box::pin(EmptyRecordBatchStream::new(left_schema));
1403    }
1404
1405    /// Asynchronously pulls the next batch from the right stream.
1406    ///
1407    /// This default implementation checks for the next value in the right stream.
1408    /// If a batch is found, the state is switched to `PullLeft`, and the batch handling
1409    /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`.
1410    ///
1411    /// # Returns
1412    ///
1413    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1414    fn fetch_next_from_right_stream(
1415        &mut self,
1416        cx: &mut Context<'_>,
1417    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1418        match ready!(self.right_stream().poll_next_unpin(cx)) {
1419            Some(Ok(batch)) => {
1420                if batch.num_rows() == 0 {
1421                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1422                }
1423                self.set_state(SHJStreamState::PullLeft);
1424                Poll::Ready(self.process_batch_from_right(&batch))
1425            }
1426            Some(Err(e)) => Poll::Ready(Err(e)),
1427            None => {
1428                self.cleanup_depleted_right_stream();
1429                self.set_state(SHJStreamState::RightExhausted);
1430                Poll::Ready(Ok(StatefulStreamResult::Continue))
1431            }
1432        }
1433    }
1434
1435    /// Asynchronously pulls the next batch from the left stream.
1436    ///
1437    /// This default implementation checks for the next value in the left stream.
1438    /// If a batch is found, the state is switched to `PullRight`, and the batch handling
1439    /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`.
1440    ///
1441    /// # Returns
1442    ///
1443    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1444    fn fetch_next_from_left_stream(
1445        &mut self,
1446        cx: &mut Context<'_>,
1447    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1448        match ready!(self.left_stream().poll_next_unpin(cx)) {
1449            Some(Ok(batch)) => {
1450                if batch.num_rows() == 0 {
1451                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1452                }
1453                self.set_state(SHJStreamState::PullRight);
1454                Poll::Ready(self.process_batch_from_left(&batch))
1455            }
1456            Some(Err(e)) => Poll::Ready(Err(e)),
1457            None => {
1458                self.cleanup_depleted_left_stream();
1459                self.set_state(SHJStreamState::LeftExhausted);
1460                Poll::Ready(Ok(StatefulStreamResult::Continue))
1461            }
1462        }
1463    }
1464
1465    /// Asynchronously handles the scenario when the right stream is exhausted.
1466    ///
1467    /// In this default implementation, when the right stream is exhausted, it attempts
1468    /// to pull from the left stream. If a batch is found in the left stream, it delegates
1469    /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set
1470    /// to indicate both streams are exhausted without final results yet.
1471    ///
1472    /// # Returns
1473    ///
1474    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1475    fn handle_right_stream_end(
1476        &mut self,
1477        cx: &mut Context<'_>,
1478    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1479        match ready!(self.left_stream().poll_next_unpin(cx)) {
1480            Some(Ok(batch)) => {
1481                if batch.num_rows() == 0 {
1482                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1483                }
1484                Poll::Ready(self.process_batch_after_right_end(&batch))
1485            }
1486            Some(Err(e)) => Poll::Ready(Err(e)),
1487            None => {
1488                self.cleanup_depleted_left_stream();
1489                self.set_state(SHJStreamState::BothExhausted {
1490                    final_result: false,
1491                });
1492                Poll::Ready(Ok(StatefulStreamResult::Continue))
1493            }
1494        }
1495    }
1496
1497    /// Asynchronously handles the scenario when the left stream is exhausted.
1498    ///
1499    /// When the left stream is exhausted, this default
1500    /// implementation tries to pull from the right stream and delegates the batch
1501    /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state
1502    /// is updated to indicate so.
1503    ///
1504    /// # Returns
1505    ///
1506    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1507    fn handle_left_stream_end(
1508        &mut self,
1509        cx: &mut Context<'_>,
1510    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1511        match ready!(self.right_stream().poll_next_unpin(cx)) {
1512            Some(Ok(batch)) => {
1513                if batch.num_rows() == 0 {
1514                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1515                }
1516                Poll::Ready(self.process_batch_after_left_end(&batch))
1517            }
1518            Some(Err(e)) => Poll::Ready(Err(e)),
1519            None => {
1520                self.cleanup_depleted_right_stream();
1521                self.set_state(SHJStreamState::BothExhausted {
1522                    final_result: false,
1523                });
1524                Poll::Ready(Ok(StatefulStreamResult::Continue))
1525            }
1526        }
1527    }
1528
1529    /// Handles the state when both streams are exhausted and final results are yet to be produced.
1530    ///
1531    /// This default implementation switches the state to indicate both streams are
1532    /// exhausted with final results and then invokes the handling for this specific
1533    /// scenario via `process_batches_before_finalization`.
1534    ///
1535    /// # Returns
1536    ///
1537    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
1538    fn prepare_for_final_results_after_exhaustion(
1539        &mut self,
1540    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1541        self.set_state(SHJStreamState::BothExhausted { final_result: true });
1542        self.process_batches_before_finalization()
1543    }
1544
1545    fn process_batch_from_right(
1546        &mut self,
1547        batch: &RecordBatch,
1548    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1549        self.perform_join_for_given_side(batch, JoinSide::Right)
1550            .map(|maybe_batch| {
1551                if maybe_batch.is_some() {
1552                    StatefulStreamResult::Ready(maybe_batch)
1553                } else {
1554                    StatefulStreamResult::Continue
1555                }
1556            })
1557    }
1558
1559    fn process_batch_from_left(
1560        &mut self,
1561        batch: &RecordBatch,
1562    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1563        self.perform_join_for_given_side(batch, JoinSide::Left)
1564            .map(|maybe_batch| {
1565                if maybe_batch.is_some() {
1566                    StatefulStreamResult::Ready(maybe_batch)
1567                } else {
1568                    StatefulStreamResult::Continue
1569                }
1570            })
1571    }
1572
1573    fn process_batch_after_left_end(
1574        &mut self,
1575        right_batch: &RecordBatch,
1576    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1577        self.process_batch_from_right(right_batch)
1578    }
1579
1580    fn process_batch_after_right_end(
1581        &mut self,
1582        left_batch: &RecordBatch,
1583    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1584        self.process_batch_from_left(left_batch)
1585    }
1586
1587    fn process_batches_before_finalization(
1588        &mut self,
1589    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1590        // Get the left side results:
1591        let left_result = build_side_determined_results(
1592            &self.left,
1593            &self.schema,
1594            self.left.input_buffer.num_rows(),
1595            self.right.input_buffer.schema(),
1596            self.join_type,
1597            &self.column_indices,
1598        )?;
1599        // Get the right side results:
1600        let right_result = build_side_determined_results(
1601            &self.right,
1602            &self.schema,
1603            self.right.input_buffer.num_rows(),
1604            self.left.input_buffer.schema(),
1605            self.join_type,
1606            &self.column_indices,
1607        )?;
1608
1609        // Combine the left and right results:
1610        let result = combine_two_batches(&self.schema, left_result, right_result)?;
1611
1612        // Return the result:
1613        if result.is_some() {
1614            return Ok(StatefulStreamResult::Ready(result));
1615        }
1616        Ok(StatefulStreamResult::Continue)
1617    }
1618
1619    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
1620        &mut self.right_stream
1621    }
1622
1623    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
1624        &mut self.left_stream
1625    }
1626
1627    fn set_state(&mut self, state: SHJStreamState) {
1628        self.state = state;
1629    }
1630
1631    fn state(&mut self) -> SHJStreamState {
1632        self.state.clone()
1633    }
1634
1635    fn size(&self) -> usize {
1636        let mut size = 0;
1637        size += size_of_val(&self.schema);
1638        size += size_of_val(&self.filter);
1639        size += size_of_val(&self.join_type);
1640        size += self.left.size();
1641        size += self.right.size();
1642        size += size_of_val(&self.column_indices);
1643        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
1644        size += size_of_val(&self.left_sorted_filter_expr);
1645        size += size_of_val(&self.right_sorted_filter_expr);
1646        size += size_of_val(&self.random_state);
1647        size += size_of_val(&self.null_equality);
1648        size += size_of_val(&self.metrics);
1649        size
1650    }
1651
1652    /// Performs a join operation for the specified `probe_side` (either left or right).
1653    /// This function:
1654    /// 1. Determines which side is the probe and which is the build side.
1655    /// 2. Updates metrics based on the batch that was polled.
1656    /// 3. Executes the join with the given `probe_batch`.
1657    /// 4. Optionally computes anti-join results if all conditions are met.
1658    /// 5. Combines the results and returns a combined batch or `None` if no batch was produced.
1659    fn perform_join_for_given_side(
1660        &mut self,
1661        probe_batch: &RecordBatch,
1662        probe_side: JoinSide,
1663    ) -> Result<Option<RecordBatch>> {
1664        let (
1665            probe_hash_joiner,
1666            build_hash_joiner,
1667            probe_side_sorted_filter_expr,
1668            build_side_sorted_filter_expr,
1669            probe_side_metrics,
1670        ) = if probe_side.eq(&JoinSide::Left) {
1671            (
1672                &mut self.left,
1673                &mut self.right,
1674                &mut self.left_sorted_filter_expr,
1675                &mut self.right_sorted_filter_expr,
1676                &mut self.metrics.left,
1677            )
1678        } else {
1679            (
1680                &mut self.right,
1681                &mut self.left,
1682                &mut self.right_sorted_filter_expr,
1683                &mut self.left_sorted_filter_expr,
1684                &mut self.metrics.right,
1685            )
1686        };
1687        // Update the metrics for the stream that was polled:
1688        probe_side_metrics.input_batches.add(1);
1689        probe_side_metrics.input_rows.add(probe_batch.num_rows());
1690        // Update the internal state of the hash joiner for the build side:
1691        probe_hash_joiner.update_internal_state(probe_batch, &self.random_state)?;
1692        // Join the two sides:
1693        let equal_result = join_with_probe_batch(
1694            build_hash_joiner,
1695            probe_hash_joiner,
1696            &self.schema,
1697            self.join_type,
1698            self.filter.as_ref(),
1699            probe_batch,
1700            &self.column_indices,
1701            &self.random_state,
1702            self.null_equality,
1703        )?;
1704        // Increment the offset for the probe hash joiner:
1705        probe_hash_joiner.offset += probe_batch.num_rows();
1706
1707        let anti_result = if let (
1708            Some(build_side_sorted_filter_expr),
1709            Some(probe_side_sorted_filter_expr),
1710            Some(graph),
1711        ) = (
1712            build_side_sorted_filter_expr.as_mut(),
1713            probe_side_sorted_filter_expr.as_mut(),
1714            self.graph.as_mut(),
1715        ) {
1716            // Calculate filter intervals:
1717            calculate_filter_expr_intervals(
1718                &build_hash_joiner.input_buffer,
1719                build_side_sorted_filter_expr,
1720                probe_batch,
1721                probe_side_sorted_filter_expr,
1722            )?;
1723            let prune_length = build_hash_joiner
1724                .calculate_prune_length_with_probe_batch(
1725                    build_side_sorted_filter_expr,
1726                    probe_side_sorted_filter_expr,
1727                    graph,
1728                )?;
1729            let result = build_side_determined_results(
1730                build_hash_joiner,
1731                &self.schema,
1732                prune_length,
1733                probe_batch.schema(),
1734                self.join_type,
1735                &self.column_indices,
1736            )?;
1737            build_hash_joiner.prune_internal_state(prune_length)?;
1738            result
1739        } else {
1740            None
1741        };
1742
1743        // Combine results:
1744        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
1745        let capacity = self.size();
1746        self.metrics.stream_memory_usage.set(capacity);
1747        self.reservation.try_resize(capacity)?;
1748        Ok(result)
1749    }
1750}
1751
/// Represents the various states of a symmetric hash join stream operation.
///
/// This enum is used to track the current state of streaming during a join
/// operation. It provides indicators as to which side of the join needs to be
/// pulled next or if one (or both) sides have been exhausted. This allows
/// for efficient management of resources and optimal performance during the
/// join process.
///
/// States can be compared by value (`==` / `!=`), which is convenient in
/// assertions and when checking for a specific state.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SHJStreamState {
    /// Indicates that the next step should pull from the right side of the join.
    PullRight,

    /// Indicates that the next step should pull from the left side of the join.
    PullLeft,

    /// State representing that the right side of the join has been fully processed.
    RightExhausted,

    /// State representing that the left side of the join has been fully processed.
    LeftExhausted,

    /// Represents a state where both sides of the join are exhausted.
    ///
    /// The `final_result` field indicates whether the join operation has
    /// produced a final result or not.
    BothExhausted { final_result: bool },
}
1779
1780#[cfg(test)]
1781mod tests {
1782    use std::collections::HashMap;
1783    use std::sync::{LazyLock, Mutex};
1784
1785    use super::*;
1786    use crate::joins::test_utils::{
1787        build_sides_record_batches, compare_batches, complicated_filter,
1788        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1789        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1790        partitioned_sym_join_with_filter, split_record_batches,
1791    };
1792
1793    use arrow::compute::SortOptions;
1794    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1795    use datafusion_common::ScalarValue;
1796    use datafusion_execution::config::SessionConfig;
1797    use datafusion_expr::Operator;
1798    use datafusion_physical_expr::expressions::{Column, binary, col, lit};
1799    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
1800
1801    use rstest::*;
1802
1803    const TABLE_SIZE: i32 = 30;
1804
1805    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
1806    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)
1807
1808    // Cache for storing tables
1809    static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
1810        LazyLock::new(|| Mutex::new(HashMap::new()));
1811
1812    fn get_or_create_table(
1813        cardinality: (i32, i32),
1814        batch_size: usize,
1815    ) -> Result<TableValue> {
1816        {
1817            let cache = TABLE_CACHE.lock().unwrap();
1818            if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1819                return Ok(table.clone());
1820            }
1821        }
1822
1823        // If not, create the table
1824        let (left_batch, right_batch) =
1825            build_sides_record_batches(TABLE_SIZE, cardinality)?;
1826
1827        let (left_partition, right_partition) = (
1828            split_record_batches(&left_batch, batch_size)?,
1829            split_record_batches(&right_batch, batch_size)?,
1830        );
1831
1832        // Lock the cache again and store the table
1833        let mut cache = TABLE_CACHE.lock().unwrap();
1834
1835        // Store the table in the cache
1836        cache.insert(
1837            (cardinality.0, cardinality.1, batch_size),
1838            (left_partition.clone(), right_partition.clone()),
1839        );
1840
1841        Ok((left_partition, right_partition))
1842    }
1843
1844    pub async fn experiment(
1845        left: Arc<dyn ExecutionPlan>,
1846        right: Arc<dyn ExecutionPlan>,
1847        filter: Option<JoinFilter>,
1848        join_type: JoinType,
1849        on: JoinOn,
1850        task_ctx: Arc<TaskContext>,
1851    ) -> Result<()> {
1852        let first_batches = partitioned_sym_join_with_filter(
1853            Arc::clone(&left),
1854            Arc::clone(&right),
1855            on.clone(),
1856            filter.clone(),
1857            &join_type,
1858            NullEquality::NullEqualsNothing,
1859            Arc::clone(&task_ctx),
1860        )
1861        .await?;
1862        let second_batches = partitioned_hash_join_with_filter(
1863            left,
1864            right,
1865            on,
1866            filter,
1867            &join_type,
1868            NullEquality::NullEqualsNothing,
1869            task_ctx,
1870        )
1871        .await?;
1872        compare_batches(&first_batches, &second_batches);
1873        Ok(())
1874    }
1875
1876    #[rstest]
1877    #[tokio::test(flavor = "multi_thread")]
1878    async fn complex_join_all_one_ascending_numeric(
1879        #[values(
1880            JoinType::Inner,
1881            JoinType::Left,
1882            JoinType::Right,
1883            JoinType::RightSemi,
1884            JoinType::LeftSemi,
1885            JoinType::LeftAnti,
1886            JoinType::LeftMark,
1887            JoinType::RightAnti,
1888            JoinType::RightMark,
1889            JoinType::Full
1890        )]
1891        join_type: JoinType,
1892        #[values(
1893        (4, 5),
1894        (12, 17),
1895        )]
1896        cardinality: (i32, i32),
1897    ) -> Result<()> {
1898        // a + b > c + 10 AND a + b < c + 100
1899        let task_ctx = Arc::new(TaskContext::default());
1900
1901        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
1902
1903        let left_schema = &left_partition[0].schema();
1904        let right_schema = &right_partition[0].schema();
1905
1906        let left_sorted = [PhysicalSortExpr {
1907            expr: binary(
1908                col("la1", left_schema)?,
1909                Operator::Plus,
1910                col("la2", left_schema)?,
1911                left_schema,
1912            )?,
1913            options: SortOptions::default(),
1914        }]
1915        .into();
1916        let right_sorted = [PhysicalSortExpr {
1917            expr: col("ra1", right_schema)?,
1918            options: SortOptions::default(),
1919        }]
1920        .into();
1921        let (left, right) = create_memory_table(
1922            left_partition,
1923            right_partition,
1924            vec![left_sorted],
1925            vec![right_sorted],
1926        )?;
1927
1928        let on = vec![(
1929            binary(
1930                col("lc1", left_schema)?,
1931                Operator::Plus,
1932                lit(ScalarValue::Int32(Some(1))),
1933                left_schema,
1934            )?,
1935            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1936        )];
1937
1938        let intermediate_schema = Schema::new(vec![
1939            Field::new("0", DataType::Int32, true),
1940            Field::new("1", DataType::Int32, true),
1941            Field::new("2", DataType::Int32, true),
1942        ]);
1943        let filter_expr = complicated_filter(&intermediate_schema)?;
1944        let column_indices = vec![
1945            ColumnIndex {
1946                index: left_schema.index_of("la1")?,
1947                side: JoinSide::Left,
1948            },
1949            ColumnIndex {
1950                index: left_schema.index_of("la2")?,
1951                side: JoinSide::Left,
1952            },
1953            ColumnIndex {
1954                index: right_schema.index_of("ra1")?,
1955                side: JoinSide::Right,
1956            },
1957        ];
1958        let filter =
1959            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1960
1961        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1962        Ok(())
1963    }
1964
1965    #[rstest]
1966    #[tokio::test(flavor = "multi_thread")]
1967    async fn join_all_one_ascending_numeric(
1968        #[values(
1969            JoinType::Inner,
1970            JoinType::Left,
1971            JoinType::Right,
1972            JoinType::RightSemi,
1973            JoinType::LeftSemi,
1974            JoinType::LeftAnti,
1975            JoinType::LeftMark,
1976            JoinType::RightAnti,
1977            JoinType::RightMark,
1978            JoinType::Full
1979        )]
1980        join_type: JoinType,
1981        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1982    ) -> Result<()> {
1983        let task_ctx = Arc::new(TaskContext::default());
1984        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1985
1986        let left_schema = &left_partition[0].schema();
1987        let right_schema = &right_partition[0].schema();
1988
1989        let left_sorted = [PhysicalSortExpr {
1990            expr: col("la1", left_schema)?,
1991            options: SortOptions::default(),
1992        }]
1993        .into();
1994        let right_sorted = [PhysicalSortExpr {
1995            expr: col("ra1", right_schema)?,
1996            options: SortOptions::default(),
1997        }]
1998        .into();
1999        let (left, right) = create_memory_table(
2000            left_partition,
2001            right_partition,
2002            vec![left_sorted],
2003            vec![right_sorted],
2004        )?;
2005
2006        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2007
2008        let intermediate_schema = Schema::new(vec![
2009            Field::new("left", DataType::Int32, true),
2010            Field::new("right", DataType::Int32, true),
2011        ]);
2012        let filter_expr = join_expr_tests_fixture_i32(
2013            case_expr,
2014            col("left", &intermediate_schema)?,
2015            col("right", &intermediate_schema)?,
2016        );
2017        let column_indices = vec![
2018            ColumnIndex {
2019                index: 0,
2020                side: JoinSide::Left,
2021            },
2022            ColumnIndex {
2023                index: 0,
2024                side: JoinSide::Right,
2025            },
2026        ];
2027        let filter =
2028            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2029
2030        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2031        Ok(())
2032    }
2033
2034    #[rstest]
2035    #[tokio::test(flavor = "multi_thread")]
2036    async fn join_without_sort_information(
2037        #[values(
2038            JoinType::Inner,
2039            JoinType::Left,
2040            JoinType::Right,
2041            JoinType::RightSemi,
2042            JoinType::LeftSemi,
2043            JoinType::LeftAnti,
2044            JoinType::LeftMark,
2045            JoinType::RightAnti,
2046            JoinType::RightMark,
2047            JoinType::Full
2048        )]
2049        join_type: JoinType,
2050        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2051    ) -> Result<()> {
2052        let task_ctx = Arc::new(TaskContext::default());
2053        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
2054
2055        let left_schema = &left_partition[0].schema();
2056        let right_schema = &right_partition[0].schema();
2057        let (left, right) =
2058            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2059
2060        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2061
2062        let intermediate_schema = Schema::new(vec![
2063            Field::new("left", DataType::Int32, true),
2064            Field::new("right", DataType::Int32, true),
2065        ]);
2066        let filter_expr = join_expr_tests_fixture_i32(
2067            case_expr,
2068            col("left", &intermediate_schema)?,
2069            col("right", &intermediate_schema)?,
2070        );
2071        let column_indices = vec![
2072            ColumnIndex {
2073                index: 5,
2074                side: JoinSide::Left,
2075            },
2076            ColumnIndex {
2077                index: 5,
2078                side: JoinSide::Right,
2079            },
2080        ];
2081        let filter =
2082            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2083
2084        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2085        Ok(())
2086    }
2087
2088    #[rstest]
2089    #[tokio::test(flavor = "multi_thread")]
2090    async fn join_without_filter(
2091        #[values(
2092            JoinType::Inner,
2093            JoinType::Left,
2094            JoinType::Right,
2095            JoinType::RightSemi,
2096            JoinType::LeftSemi,
2097            JoinType::LeftAnti,
2098            JoinType::LeftMark,
2099            JoinType::RightAnti,
2100            JoinType::RightMark,
2101            JoinType::Full
2102        )]
2103        join_type: JoinType,
2104    ) -> Result<()> {
2105        let task_ctx = Arc::new(TaskContext::default());
2106        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2107        let left_schema = &left_partition[0].schema();
2108        let right_schema = &right_partition[0].schema();
2109        let (left, right) =
2110            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2111
2112        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2113        experiment(left, right, None, join_type, on, task_ctx).await?;
2114        Ok(())
2115    }
2116
2117    #[rstest]
2118    #[tokio::test(flavor = "multi_thread")]
2119    async fn join_all_one_descending_numeric_particular(
2120        #[values(
2121            JoinType::Inner,
2122            JoinType::Left,
2123            JoinType::Right,
2124            JoinType::RightSemi,
2125            JoinType::LeftSemi,
2126            JoinType::LeftAnti,
2127            JoinType::LeftMark,
2128            JoinType::RightAnti,
2129            JoinType::RightMark,
2130            JoinType::Full
2131        )]
2132        join_type: JoinType,
2133        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2134    ) -> Result<()> {
2135        let task_ctx = Arc::new(TaskContext::default());
2136        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2137        let left_schema = &left_partition[0].schema();
2138        let right_schema = &right_partition[0].schema();
2139        let left_sorted = [PhysicalSortExpr {
2140            expr: col("la1_des", left_schema)?,
2141            options: SortOptions {
2142                descending: true,
2143                nulls_first: true,
2144            },
2145        }]
2146        .into();
2147        let right_sorted = [PhysicalSortExpr {
2148            expr: col("ra1_des", right_schema)?,
2149            options: SortOptions {
2150                descending: true,
2151                nulls_first: true,
2152            },
2153        }]
2154        .into();
2155        let (left, right) = create_memory_table(
2156            left_partition,
2157            right_partition,
2158            vec![left_sorted],
2159            vec![right_sorted],
2160        )?;
2161
2162        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2163
2164        let intermediate_schema = Schema::new(vec![
2165            Field::new("left", DataType::Int32, true),
2166            Field::new("right", DataType::Int32, true),
2167        ]);
2168        let filter_expr = join_expr_tests_fixture_i32(
2169            case_expr,
2170            col("left", &intermediate_schema)?,
2171            col("right", &intermediate_schema)?,
2172        );
2173        let column_indices = vec![
2174            ColumnIndex {
2175                index: 5,
2176                side: JoinSide::Left,
2177            },
2178            ColumnIndex {
2179                index: 5,
2180                side: JoinSide::Right,
2181            },
2182        ];
2183        let filter =
2184            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2185
2186        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2187        Ok(())
2188    }
2189
2190    #[tokio::test(flavor = "multi_thread")]
2191    async fn build_null_columns_first() -> Result<()> {
2192        let join_type = JoinType::Full;
2193        let case_expr = 1;
2194        let session_config = SessionConfig::new().with_repartition_joins(false);
2195        let task_ctx = TaskContext::default().with_session_config(session_config);
2196        let task_ctx = Arc::new(task_ctx);
2197        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2198        let left_schema = &left_partition[0].schema();
2199        let right_schema = &right_partition[0].schema();
2200        let left_sorted = [PhysicalSortExpr {
2201            expr: col("l_asc_null_first", left_schema)?,
2202            options: SortOptions {
2203                descending: false,
2204                nulls_first: true,
2205            },
2206        }]
2207        .into();
2208        let right_sorted = [PhysicalSortExpr {
2209            expr: col("r_asc_null_first", right_schema)?,
2210            options: SortOptions {
2211                descending: false,
2212                nulls_first: true,
2213            },
2214        }]
2215        .into();
2216        let (left, right) = create_memory_table(
2217            left_partition,
2218            right_partition,
2219            vec![left_sorted],
2220            vec![right_sorted],
2221        )?;
2222
2223        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2224
2225        let intermediate_schema = Schema::new(vec![
2226            Field::new("left", DataType::Int32, true),
2227            Field::new("right", DataType::Int32, true),
2228        ]);
2229        let filter_expr = join_expr_tests_fixture_i32(
2230            case_expr,
2231            col("left", &intermediate_schema)?,
2232            col("right", &intermediate_schema)?,
2233        );
2234        let column_indices = vec![
2235            ColumnIndex {
2236                index: 6,
2237                side: JoinSide::Left,
2238            },
2239            ColumnIndex {
2240                index: 6,
2241                side: JoinSide::Right,
2242            },
2243        ];
2244        let filter =
2245            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2246        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2247        Ok(())
2248    }
2249
2250    #[tokio::test(flavor = "multi_thread")]
2251    async fn build_null_columns_last() -> Result<()> {
2252        let join_type = JoinType::Full;
2253        let case_expr = 1;
2254        let session_config = SessionConfig::new().with_repartition_joins(false);
2255        let task_ctx = TaskContext::default().with_session_config(session_config);
2256        let task_ctx = Arc::new(task_ctx);
2257        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
2258
2259        let left_schema = &left_partition[0].schema();
2260        let right_schema = &right_partition[0].schema();
2261        let left_sorted = [PhysicalSortExpr {
2262            expr: col("l_asc_null_last", left_schema)?,
2263            options: SortOptions {
2264                descending: false,
2265                nulls_first: false,
2266            },
2267        }]
2268        .into();
2269        let right_sorted = [PhysicalSortExpr {
2270            expr: col("r_asc_null_last", right_schema)?,
2271            options: SortOptions {
2272                descending: false,
2273                nulls_first: false,
2274            },
2275        }]
2276        .into();
2277        let (left, right) = create_memory_table(
2278            left_partition,
2279            right_partition,
2280            vec![left_sorted],
2281            vec![right_sorted],
2282        )?;
2283
2284        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2285
2286        let intermediate_schema = Schema::new(vec![
2287            Field::new("left", DataType::Int32, true),
2288            Field::new("right", DataType::Int32, true),
2289        ]);
2290        let filter_expr = join_expr_tests_fixture_i32(
2291            case_expr,
2292            col("left", &intermediate_schema)?,
2293            col("right", &intermediate_schema)?,
2294        );
2295        let column_indices = vec![
2296            ColumnIndex {
2297                index: 7,
2298                side: JoinSide::Left,
2299            },
2300            ColumnIndex {
2301                index: 7,
2302                side: JoinSide::Right,
2303            },
2304        ];
2305        let filter =
2306            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2307
2308        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2309        Ok(())
2310    }
2311
2312    #[tokio::test(flavor = "multi_thread")]
2313    async fn build_null_columns_first_descending() -> Result<()> {
2314        let join_type = JoinType::Full;
2315        let cardinality = (10, 11);
2316        let case_expr = 1;
2317        let session_config = SessionConfig::new().with_repartition_joins(false);
2318        let task_ctx = TaskContext::default().with_session_config(session_config);
2319        let task_ctx = Arc::new(task_ctx);
2320        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2321
2322        let left_schema = &left_partition[0].schema();
2323        let right_schema = &right_partition[0].schema();
2324        let left_sorted = [PhysicalSortExpr {
2325            expr: col("l_desc_null_first", left_schema)?,
2326            options: SortOptions {
2327                descending: true,
2328                nulls_first: true,
2329            },
2330        }]
2331        .into();
2332        let right_sorted = [PhysicalSortExpr {
2333            expr: col("r_desc_null_first", right_schema)?,
2334            options: SortOptions {
2335                descending: true,
2336                nulls_first: true,
2337            },
2338        }]
2339        .into();
2340        let (left, right) = create_memory_table(
2341            left_partition,
2342            right_partition,
2343            vec![left_sorted],
2344            vec![right_sorted],
2345        )?;
2346
2347        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2348
2349        let intermediate_schema = Schema::new(vec![
2350            Field::new("left", DataType::Int32, true),
2351            Field::new("right", DataType::Int32, true),
2352        ]);
2353        let filter_expr = join_expr_tests_fixture_i32(
2354            case_expr,
2355            col("left", &intermediate_schema)?,
2356            col("right", &intermediate_schema)?,
2357        );
2358        let column_indices = vec![
2359            ColumnIndex {
2360                index: 8,
2361                side: JoinSide::Left,
2362            },
2363            ColumnIndex {
2364                index: 8,
2365                side: JoinSide::Right,
2366            },
2367        ];
2368        let filter =
2369            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2370
2371        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2372        Ok(())
2373    }
2374
2375    #[tokio::test(flavor = "multi_thread")]
2376    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2377        let cardinality = (3, 4);
2378        let join_type = JoinType::Full;
2379
2380        // a + b > c + 10 AND a + b < c + 100
2381        let session_config = SessionConfig::new().with_repartition_joins(false);
2382        let task_ctx = TaskContext::default().with_session_config(session_config);
2383        let task_ctx = Arc::new(task_ctx);
2384        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2385
2386        let left_schema = &left_partition[0].schema();
2387        let right_schema = &right_partition[0].schema();
2388        let left_sorted = [PhysicalSortExpr {
2389            expr: col("la1", left_schema)?,
2390            options: SortOptions::default(),
2391        }]
2392        .into();
2393        let right_sorted = [PhysicalSortExpr {
2394            expr: col("ra1", right_schema)?,
2395            options: SortOptions::default(),
2396        }]
2397        .into();
2398        let (left, right) = create_memory_table(
2399            left_partition,
2400            right_partition,
2401            vec![left_sorted],
2402            vec![right_sorted],
2403        )?;
2404
2405        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2406
2407        let intermediate_schema = Schema::new(vec![
2408            Field::new("0", DataType::Int32, true),
2409            Field::new("1", DataType::Int32, true),
2410            Field::new("2", DataType::Int32, true),
2411        ]);
2412        let filter_expr = complicated_filter(&intermediate_schema)?;
2413        let column_indices = vec![
2414            ColumnIndex {
2415                index: 0,
2416                side: JoinSide::Left,
2417            },
2418            ColumnIndex {
2419                index: 4,
2420                side: JoinSide::Left,
2421            },
2422            ColumnIndex {
2423                index: 0,
2424                side: JoinSide::Right,
2425            },
2426        ];
2427        let filter =
2428            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2429
2430        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2431        Ok(())
2432    }
2433
2434    #[tokio::test(flavor = "multi_thread")]
2435    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2436        let cardinality = (3, 4);
2437        let join_type = JoinType::Full;
2438
2439        // a + b > c + 10 AND a + b < c + 100
2440        let config = SessionConfig::new().with_repartition_joins(false);
2441        // let session_ctx = SessionContext::with_config(config);
2442        // let task_ctx = session_ctx.task_ctx();
2443        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2444        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2445        let left_schema = &left_partition[0].schema();
2446        let right_schema = &right_partition[0].schema();
2447        let left_sorted = vec![
2448            [PhysicalSortExpr {
2449                expr: col("la1", left_schema)?,
2450                options: SortOptions::default(),
2451            }]
2452            .into(),
2453            [PhysicalSortExpr {
2454                expr: col("la2", left_schema)?,
2455                options: SortOptions::default(),
2456            }]
2457            .into(),
2458        ];
2459
2460        let right_sorted = [PhysicalSortExpr {
2461            expr: col("ra1", right_schema)?,
2462            options: SortOptions::default(),
2463        }]
2464        .into();
2465
2466        let (left, right) = create_memory_table(
2467            left_partition,
2468            right_partition,
2469            left_sorted,
2470            vec![right_sorted],
2471        )?;
2472
2473        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2474
2475        let intermediate_schema = Schema::new(vec![
2476            Field::new("0", DataType::Int32, true),
2477            Field::new("1", DataType::Int32, true),
2478            Field::new("2", DataType::Int32, true),
2479        ]);
2480        let filter_expr = complicated_filter(&intermediate_schema)?;
2481        let column_indices = vec![
2482            ColumnIndex {
2483                index: 0,
2484                side: JoinSide::Left,
2485            },
2486            ColumnIndex {
2487                index: 4,
2488                side: JoinSide::Left,
2489            },
2490            ColumnIndex {
2491                index: 0,
2492                side: JoinSide::Right,
2493            },
2494        ];
2495        let filter =
2496            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2497
2498        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2499        Ok(())
2500    }
2501
2502    #[rstest]
2503    #[tokio::test(flavor = "multi_thread")]
2504    async fn testing_with_temporal_columns(
2505        #[values(
2506            JoinType::Inner,
2507            JoinType::Left,
2508            JoinType::Right,
2509            JoinType::RightSemi,
2510            JoinType::LeftSemi,
2511            JoinType::LeftAnti,
2512            JoinType::LeftMark,
2513            JoinType::RightAnti,
2514            JoinType::RightMark,
2515            JoinType::Full
2516        )]
2517        join_type: JoinType,
2518        #[values(
2519            (4, 5),
2520            (12, 17),
2521        )]
2522        cardinality: (i32, i32),
2523        #[values(0, 1, 2)] case_expr: usize,
2524    ) -> Result<()> {
2525        let session_config = SessionConfig::new().with_repartition_joins(false);
2526        let task_ctx = TaskContext::default().with_session_config(session_config);
2527        let task_ctx = Arc::new(task_ctx);
2528        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2529
2530        let left_schema = &left_partition[0].schema();
2531        let right_schema = &right_partition[0].schema();
2532        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2533        let left_sorted = [PhysicalSortExpr {
2534            expr: col("lt1", left_schema)?,
2535            options: SortOptions {
2536                descending: false,
2537                nulls_first: true,
2538            },
2539        }]
2540        .into();
2541        let right_sorted = [PhysicalSortExpr {
2542            expr: col("rt1", right_schema)?,
2543            options: SortOptions {
2544                descending: false,
2545                nulls_first: true,
2546            },
2547        }]
2548        .into();
2549        let (left, right) = create_memory_table(
2550            left_partition,
2551            right_partition,
2552            vec![left_sorted],
2553            vec![right_sorted],
2554        )?;
2555        let intermediate_schema = Schema::new(vec![
2556            Field::new(
2557                "left",
2558                DataType::Timestamp(TimeUnit::Millisecond, None),
2559                false,
2560            ),
2561            Field::new(
2562                "right",
2563                DataType::Timestamp(TimeUnit::Millisecond, None),
2564                false,
2565            ),
2566        ]);
2567        let filter_expr = join_expr_tests_fixture_temporal(
2568            case_expr,
2569            col("left", &intermediate_schema)?,
2570            col("right", &intermediate_schema)?,
2571            &intermediate_schema,
2572        )?;
2573        let column_indices = vec![
2574            ColumnIndex {
2575                index: 3,
2576                side: JoinSide::Left,
2577            },
2578            ColumnIndex {
2579                index: 3,
2580                side: JoinSide::Right,
2581            },
2582        ];
2583        let filter =
2584            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2585        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2586        Ok(())
2587    }
2588
2589    #[rstest]
2590    #[tokio::test(flavor = "multi_thread")]
2591    async fn test_with_interval_columns(
2592        #[values(
2593            JoinType::Inner,
2594            JoinType::Left,
2595            JoinType::Right,
2596            JoinType::RightSemi,
2597            JoinType::LeftSemi,
2598            JoinType::LeftAnti,
2599            JoinType::LeftMark,
2600            JoinType::RightAnti,
2601            JoinType::RightMark,
2602            JoinType::Full
2603        )]
2604        join_type: JoinType,
2605        #[values(
2606            (4, 5),
2607            (12, 17),
2608        )]
2609        cardinality: (i32, i32),
2610    ) -> Result<()> {
2611        let session_config = SessionConfig::new().with_repartition_joins(false);
2612        let task_ctx = TaskContext::default().with_session_config(session_config);
2613        let task_ctx = Arc::new(task_ctx);
2614        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2615
2616        let left_schema = &left_partition[0].schema();
2617        let right_schema = &right_partition[0].schema();
2618        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2619        let left_sorted = [PhysicalSortExpr {
2620            expr: col("li1", left_schema)?,
2621            options: SortOptions {
2622                descending: false,
2623                nulls_first: true,
2624            },
2625        }]
2626        .into();
2627        let right_sorted = [PhysicalSortExpr {
2628            expr: col("ri1", right_schema)?,
2629            options: SortOptions {
2630                descending: false,
2631                nulls_first: true,
2632            },
2633        }]
2634        .into();
2635        let (left, right) = create_memory_table(
2636            left_partition,
2637            right_partition,
2638            vec![left_sorted],
2639            vec![right_sorted],
2640        )?;
2641        let intermediate_schema = Schema::new(vec![
2642            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2643            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2644        ]);
2645        let filter_expr = join_expr_tests_fixture_temporal(
2646            0,
2647            col("left", &intermediate_schema)?,
2648            col("right", &intermediate_schema)?,
2649            &intermediate_schema,
2650        )?;
2651        let column_indices = vec![
2652            ColumnIndex {
2653                index: 9,
2654                side: JoinSide::Left,
2655            },
2656            ColumnIndex {
2657                index: 9,
2658                side: JoinSide::Right,
2659            },
2660        ];
2661        let filter =
2662            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2663        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2664
2665        Ok(())
2666    }
2667
2668    #[rstest]
2669    #[tokio::test(flavor = "multi_thread")]
2670    async fn testing_ascending_float_pruning(
2671        #[values(
2672            JoinType::Inner,
2673            JoinType::Left,
2674            JoinType::Right,
2675            JoinType::RightSemi,
2676            JoinType::LeftSemi,
2677            JoinType::LeftAnti,
2678            JoinType::LeftMark,
2679            JoinType::RightAnti,
2680            JoinType::RightMark,
2681            JoinType::Full
2682        )]
2683        join_type: JoinType,
2684        #[values(
2685            (4, 5),
2686            (12, 17),
2687        )]
2688        cardinality: (i32, i32),
2689        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2690    ) -> Result<()> {
2691        let session_config = SessionConfig::new().with_repartition_joins(false);
2692        let task_ctx = TaskContext::default().with_session_config(session_config);
2693        let task_ctx = Arc::new(task_ctx);
2694        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2695
2696        let left_schema = &left_partition[0].schema();
2697        let right_schema = &right_partition[0].schema();
2698        let left_sorted = [PhysicalSortExpr {
2699            expr: col("l_float", left_schema)?,
2700            options: SortOptions::default(),
2701        }]
2702        .into();
2703        let right_sorted = [PhysicalSortExpr {
2704            expr: col("r_float", right_schema)?,
2705            options: SortOptions::default(),
2706        }]
2707        .into();
2708        let (left, right) = create_memory_table(
2709            left_partition,
2710            right_partition,
2711            vec![left_sorted],
2712            vec![right_sorted],
2713        )?;
2714
2715        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2716
2717        let intermediate_schema = Schema::new(vec![
2718            Field::new("left", DataType::Float64, true),
2719            Field::new("right", DataType::Float64, true),
2720        ]);
2721        let filter_expr = join_expr_tests_fixture_f64(
2722            case_expr,
2723            col("left", &intermediate_schema)?,
2724            col("right", &intermediate_schema)?,
2725        );
2726        let column_indices = vec![
2727            ColumnIndex {
2728                index: 10, // l_float
2729                side: JoinSide::Left,
2730            },
2731            ColumnIndex {
2732                index: 10, // r_float
2733                side: JoinSide::Right,
2734            },
2735        ];
2736        let filter =
2737            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2738
2739        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2740        Ok(())
2741    }
2742}