// datafusion_physical_plan/joins/symmetric_hash_join.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file implements the symmetric hash join algorithm with range-based
19//! data pruning to join two (potentially infinite) streams.
20//!
21//! A [`SymmetricHashJoinExec`] plan takes two children plan (with appropriate
22//! output ordering) and produces the join output according to the given join
23//! type and other options.
24//!
25//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations
26//! for both its children.
27
28use std::any::Any;
29use std::fmt::{self, Debug};
30use std::mem::{size_of, size_of_val};
31use std::sync::Arc;
32use std::task::{Context, Poll};
33use std::vec;
34
35use crate::common::SharedMemoryReservation;
36use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
37use crate::joins::stream_join_utils::{
38    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
39    calculate_filter_expr_intervals, combine_two_batches,
40    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
41    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
42};
43use crate::joins::utils::{
44    BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
45    JoinOnRef, NoopBatchTransformer, StatefulStreamResult, apply_join_filter_to_indices,
46    build_batch_from_indices, build_join_schema, check_join_is_valid, equal_rows_arr,
47    symmetric_join_output_partitioning, update_hash,
48};
49use crate::projection::{
50    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
51    physical_to_column_exprs, update_join_filter, update_join_on,
52};
53use crate::{
54    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
55    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
56    joins::StreamJoinPartitionMode,
57    metrics::{ExecutionPlanMetricsSet, MetricsSet},
58};
59
60use arrow::array::{
61    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
62    UInt64Array,
63};
64use arrow::compute::concat_batches;
65use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
66use arrow::record_batch::RecordBatch;
67use datafusion_common::hash_utils::create_hashes;
68use datafusion_common::utils::bisect;
69use datafusion_common::{
70    HashSet, JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err,
71    plan_err,
72};
73use datafusion_execution::TaskContext;
74use datafusion_execution::memory_pool::MemoryConsumer;
75use datafusion_expr::interval_arithmetic::Interval;
76use datafusion_physical_expr::equivalence::join_equivalence_properties;
77use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
78use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
79use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
80
81use ahash::RandomState;
82use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
83use futures::{Stream, StreamExt, ready};
84use parking_lot::Mutex;
85
// Scale factor used when shrinking the pruning join hash map; presumably the
// map is shrunk once occupancy drops below capacity / this factor — confirm at
// the use site in `PruningJoinHashMap`.
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
87
88/// A symmetric hash join with range conditions is when both streams are hashed on the
89/// join key and the resulting hash tables are used to join the streams.
90/// The join is considered symmetric because the hash table is built on the join keys from both
91/// streams, and the matching of rows is based on the values of the join keys in both streams.
92/// This type of join is efficient in streaming context as it allows for fast lookups in the hash
93/// table, rather than having to scan through one or both of the streams to find matching rows, also it
94/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions),
95/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
96/// data without any memory issues.
97///
98/// For each input stream, create a hash table.
99///   - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets.
100///   - Test if input is equal to a predefined set of other inputs.
101///   - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
102///   - Try to prune other side (probe) with new [RecordBatch].
103///   - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.),
104///     output the [RecordBatch] when a pruning happens or at the end of the data.
105///
106///
107/// ``` text
108///                        +-------------------------+
109///                        |                         |
110///   left stream ---------|  Left OneSideHashJoiner |---+
111///                        |                         |   |
112///                        +-------------------------+   |
113///                                                      |
114///                                                      |--------- Joined output
115///                                                      |
116///                        +-------------------------+   |
117///                        |                         |   |
118///  right stream ---------| Right OneSideHashJoiner |---+
119///                        |                         |
120///                        +-------------------------+
121///
122/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic
123/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
124///
125///
126///               PROBE SIDE          BUILD SIDE
127///                 BUFFER              BUFFER
128///             +-------------+     +------------+
129///             |             |     |            |    Unjoinable
130///             |             |     |            |    Range
131///             |             |     |            |
132///             |             |  |---------------------------------
133///             |             |  |  |            |
134///             |             |  |  |            |
135///             |             | /   |            |
136///             |             | |   |            |
137///             |             | |   |            |
138///             |             | |   |            |
139///             |             | |   |            |
140///             |             | |   |            |    Joinable
141///             |             |/    |            |    Range
142///             |             ||    |            |
143///             |+-----------+||    |            |
144///             || Record    ||     |            |
145///             || Batch     ||     |            |
146///             |+-----------+||    |            |
147///             +-------------+\    +------------+
148///                             |
149///                             \
150///                              |---------------------------------
151///
152///  This happens when range conditions are provided on sorted columns. E.g.
153///
154///        SELECT * FROM left_table, right_table
155///        ON
156///          left_key = right_key AND
157///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
158///
159/// or
160///       SELECT * FROM left_table, right_table
161///        ON
162///          left_key = right_key AND
163///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
164///
165/// For general purpose, in the second scenario, when the new data comes to probe side, the conditions can be used to
166/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort order the
167/// two columns ("left_sorted" and "right_sorted") are ascending (it can be different in another scenarios)
168/// and the join condition is "left_sorted > right_sorted - 3" and the latest value on the right input is 1234, meaning
169/// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" ,
170/// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending)
171/// than that can be dropped from the inner buffer.
172/// ```
#[derive(Debug, Clone)]
pub struct SymmetricHashJoinExec {
    /// Left side stream
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// Right side stream
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Pairs of (left expression, right expression) forming the equijoin keys
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    /// Filters applied when finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// Shares the `RandomState` for the hashing algorithm
    random_state: RandomState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Defines the null equality for the join.
    pub(crate) null_equality: NullEquality,
    /// Left side sort expression(s); used together with `filter` and the right
    /// side sort expression(s) to enable range-based pruning (see `execute`)
    pub(crate) left_sort_exprs: Option<LexOrdering>,
    /// Right side sort expression(s); see `left_sort_exprs`
    pub(crate) right_sort_exprs: Option<LexOrdering>,
    /// Partition Mode
    mode: StreamJoinPartitionMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
202
impl SymmetricHashJoinExec {
    /// Tries to create a new [SymmetricHashJoinExec].
    /// # Error
    /// This function errors when:
    /// - It is not possible to join the left and right sides on keys `on`, or
    /// - It fails to construct `SortedFilterExpr`s, or
    /// - It fails to create the [ExprIntervalGraph].
    #[expect(clippy::too_many_arguments)]
    pub fn try_new(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinOn,
        filter: Option<JoinFilter>,
        join_type: &JoinType,
        null_equality: NullEquality,
        left_sort_exprs: Option<LexOrdering>,
        right_sort_exprs: Option<LexOrdering>,
        mode: StreamJoinPartitionMode,
    ) -> Result<Self> {
        let left_schema = left.schema();
        let right_schema = right.schema();

        // Error out if no "on" constraints are given:
        if on.is_empty() {
            return plan_err!(
                "On constraints in SymmetricHashJoinExec should be non-empty"
            );
        }

        // Check if the join is valid with the given on constraints:
        check_join_is_valid(&left_schema, &right_schema, &on)?;

        // Build the join schema from the left and right schemas:
        let (schema, column_indices) =
            build_join_schema(&left_schema, &right_schema, join_type);

        // Initialize the random state for the join operation. Fixed seeds give
        // deterministic hashing — presumably so all partitions/instances hash
        // the join keys identically; confirm before changing.
        let random_state = RandomState::with_seeds(0, 0, 0, 0);
        let schema = Arc::new(schema);
        let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?;
        Ok(SymmetricHashJoinExec {
            left,
            right,
            on,
            filter,
            join_type: *join_type,
            random_state,
            metrics: ExecutionPlanMetricsSet::new(),
            column_indices,
            null_equality,
            left_sort_exprs,
            right_sort_exprs,
            mode,
            cache,
        })
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        left: &Arc<dyn ExecutionPlan>,
        right: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        join_type: JoinType,
        join_on: JoinOnRef,
    ) -> Result<PlanProperties> {
        // Calculate equivalence properties:
        let eq_properties = join_equivalence_properties(
            left.equivalence_properties().clone(),
            right.equivalence_properties().clone(),
            &join_type,
            schema,
            &[false, false],
            // Has alternating probe side
            None,
            join_on,
        )?;

        let output_partitioning =
            symmetric_join_output_partitioning(left, right, &join_type)?;

        // Emission type and boundedness are derived from the children; this
        // operator can stream results as both sides are consumed.
        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type_from_children([left, right]),
            boundedness_from_children([left, right]),
        ))
    }

    /// left stream
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }

    /// right stream
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }

    /// Set of common columns used to join on
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
        &self.on
    }

    /// Filters applied before join output
    pub fn filter(&self) -> Option<&JoinFilter> {
        self.filter.as_ref()
    }

    /// How the join is performed
    pub fn join_type(&self) -> &JoinType {
        &self.join_type
    }

    /// Get null_equality
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }

    /// Get partition mode
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
        self.mode
    }

    /// Get left_sort_exprs
    pub fn left_sort_exprs(&self) -> Option<&LexOrdering> {
        self.left_sort_exprs.as_ref()
    }

    /// Get right_sort_exprs
    pub fn right_sort_exprs(&self) -> Option<&LexOrdering> {
        self.right_sort_exprs.as_ref()
    }

    /// Check if order information covers every column in the filter expression.
    ///
    /// Returns `Ok(true)` only when a filter is present, both children report
    /// an output ordering, and the *leading* sort expression of each side can
    /// be converted into the filter schema.
    pub fn check_if_order_information_available(&self) -> Result<bool> {
        if let Some(filter) = self.filter() {
            let left = self.left();
            if let Some(left_ordering) = left.output_ordering() {
                let right = self.right();
                if let Some(right_ordering) = right.output_ordering() {
                    // Only the first (leading) sort expression of each ordering
                    // is checked for convertibility into the filter schema.
                    let left_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Left,
                        filter,
                        &left.schema(),
                        &left_ordering[0],
                    )?
                    .is_some();
                    let right_convertible = convert_sort_expr_with_filter_schema(
                        &JoinSide::Right,
                        filter,
                        &right.schema(),
                        &right_ordering[0],
                    )?
                    .is_some();
                    return Ok(left_convertible && right_convertible);
                }
            }
        }
        Ok(false)
    }
}
364
365impl DisplayAs for SymmetricHashJoinExec {
366    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
367        match t {
368            DisplayFormatType::Default | DisplayFormatType::Verbose => {
369                let display_filter = self.filter.as_ref().map_or_else(
370                    || "".to_string(),
371                    |f| format!(", filter={}", f.expression()),
372                );
373                let on = self
374                    .on
375                    .iter()
376                    .map(|(c1, c2)| format!("({c1}, {c2})"))
377                    .collect::<Vec<String>>()
378                    .join(", ");
379                write!(
380                    f,
381                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
382                    self.mode, self.join_type, on, display_filter
383                )
384            }
385            DisplayFormatType::TreeRender => {
386                let on = self
387                    .on
388                    .iter()
389                    .map(|(c1, c2)| {
390                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
391                    })
392                    .collect::<Vec<String>>()
393                    .join(", ");
394
395                writeln!(f, "mode={:?}", self.mode)?;
396                if *self.join_type() != JoinType::Inner {
397                    writeln!(f, "join_type={:?}", self.join_type)?;
398                }
399                writeln!(f, "on={on}")
400            }
401        }
402    }
403}
404
405impl ExecutionPlan for SymmetricHashJoinExec {
406    fn name(&self) -> &'static str {
407        "SymmetricHashJoinExec"
408    }
409
410    fn as_any(&self) -> &dyn Any {
411        self
412    }
413
414    fn properties(&self) -> &PlanProperties {
415        &self.cache
416    }
417
418    fn required_input_distribution(&self) -> Vec<Distribution> {
419        match self.mode {
420            StreamJoinPartitionMode::Partitioned => {
421                let (left_expr, right_expr) = self
422                    .on
423                    .iter()
424                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
425                    .unzip();
426                vec![
427                    Distribution::HashPartitioned(left_expr),
428                    Distribution::HashPartitioned(right_expr),
429                ]
430            }
431            StreamJoinPartitionMode::SinglePartition => {
432                vec![Distribution::SinglePartition, Distribution::SinglePartition]
433            }
434        }
435    }
436
437    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
438        vec![
439            self.left_sort_exprs
440                .as_ref()
441                .map(|e| OrderingRequirements::from(e.clone())),
442            self.right_sort_exprs
443                .as_ref()
444                .map(|e| OrderingRequirements::from(e.clone())),
445        ]
446    }
447
448    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
449        vec![&self.left, &self.right]
450    }
451
452    fn with_new_children(
453        self: Arc<Self>,
454        children: Vec<Arc<dyn ExecutionPlan>>,
455    ) -> Result<Arc<dyn ExecutionPlan>> {
456        Ok(Arc::new(SymmetricHashJoinExec::try_new(
457            Arc::clone(&children[0]),
458            Arc::clone(&children[1]),
459            self.on.clone(),
460            self.filter.clone(),
461            &self.join_type,
462            self.null_equality,
463            self.left_sort_exprs.clone(),
464            self.right_sort_exprs.clone(),
465            self.mode,
466        )?))
467    }
468
469    fn metrics(&self) -> Option<MetricsSet> {
470        Some(self.metrics.clone_inner())
471    }
472
473    fn statistics(&self) -> Result<Statistics> {
474        // TODO stats: it is not possible in general to know the output size of joins
475        Ok(Statistics::new_unknown(&self.schema()))
476    }
477
478    fn execute(
479        &self,
480        partition: usize,
481        context: Arc<TaskContext>,
482    ) -> Result<SendableRecordBatchStream> {
483        let left_partitions = self.left.output_partitioning().partition_count();
484        let right_partitions = self.right.output_partitioning().partition_count();
485        assert_eq_or_internal_err!(
486            left_partitions,
487            right_partitions,
488            "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
489                 consider using RepartitionExec"
490        );
491        // If `filter_state` and `filter` are both present, then calculate sorted
492        // filter expressions for both sides, and build an expression graph.
493        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
494            self.left_sort_exprs(),
495            self.right_sort_exprs(),
496            &self.filter,
497        ) {
498            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
499                let (left, right, graph) = prepare_sorted_exprs(
500                    filter,
501                    &self.left,
502                    &self.right,
503                    left_sort_exprs,
504                    right_sort_exprs,
505                )?;
506                (Some(left), Some(right), Some(graph))
507            }
508            // If `filter_state` or `filter` is not present, then return None
509            // for all three values:
510            _ => (None, None, None),
511        };
512
513        let (on_left, on_right) = self.on.iter().cloned().unzip();
514
515        let left_side_joiner =
516            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
517        let right_side_joiner =
518            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
519
520        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
521
522        let right_stream = self.right.execute(partition, Arc::clone(&context))?;
523
524        let batch_size = context.session_config().batch_size();
525        let enforce_batch_size_in_joins =
526            context.session_config().enforce_batch_size_in_joins();
527
528        let reservation = Arc::new(Mutex::new(
529            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
530                .register(context.memory_pool()),
531        ));
532        if let Some(g) = graph.as_ref() {
533            reservation.lock().try_grow(g.size())?;
534        }
535
536        if enforce_batch_size_in_joins {
537            Ok(Box::pin(SymmetricHashJoinStream {
538                left_stream,
539                right_stream,
540                schema: self.schema(),
541                filter: self.filter.clone(),
542                join_type: self.join_type,
543                random_state: self.random_state.clone(),
544                left: left_side_joiner,
545                right: right_side_joiner,
546                column_indices: self.column_indices.clone(),
547                metrics: StreamJoinMetrics::new(partition, &self.metrics),
548                graph,
549                left_sorted_filter_expr,
550                right_sorted_filter_expr,
551                null_equality: self.null_equality,
552                state: SHJStreamState::PullRight,
553                reservation,
554                batch_transformer: BatchSplitter::new(batch_size),
555            }))
556        } else {
557            Ok(Box::pin(SymmetricHashJoinStream {
558                left_stream,
559                right_stream,
560                schema: self.schema(),
561                filter: self.filter.clone(),
562                join_type: self.join_type,
563                random_state: self.random_state.clone(),
564                left: left_side_joiner,
565                right: right_side_joiner,
566                column_indices: self.column_indices.clone(),
567                metrics: StreamJoinMetrics::new(partition, &self.metrics),
568                graph,
569                left_sorted_filter_expr,
570                right_sorted_filter_expr,
571                null_equality: self.null_equality,
572                state: SHJStreamState::PullRight,
573                reservation,
574                batch_transformer: NoopBatchTransformer::new(),
575            }))
576        }
577    }
578
579    /// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done,
580    /// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan.
581    /// Otherwise, it returns None.
582    fn try_swapping_with_projection(
583        &self,
584        projection: &ProjectionExec,
585    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
586        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
587        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
588        else {
589            return Ok(None);
590        };
591
592        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
593            self.left().schema().fields().len(),
594            &projection_as_columns,
595        );
596
597        if !join_allows_pushdown(
598            &projection_as_columns,
599            &self.schema(),
600            far_right_left_col_ind,
601            far_left_right_col_ind,
602        ) {
603            return Ok(None);
604        }
605
606        let Some(new_on) = update_join_on(
607            &projection_as_columns[0..=far_right_left_col_ind as _],
608            &projection_as_columns[far_left_right_col_ind as _..],
609            self.on(),
610            self.left().schema().fields().len(),
611        ) else {
612            return Ok(None);
613        };
614
615        let new_filter = if let Some(filter) = self.filter() {
616            match update_join_filter(
617                &projection_as_columns[0..=far_right_left_col_ind as _],
618                &projection_as_columns[far_left_right_col_ind as _..],
619                filter,
620                self.left().schema().fields().len(),
621            ) {
622                Some(updated_filter) => Some(updated_filter),
623                None => return Ok(None),
624            }
625        } else {
626            None
627        };
628
629        let (new_left, new_right) = new_join_children(
630            &projection_as_columns,
631            far_right_left_col_ind,
632            far_left_right_col_ind,
633            self.left(),
634            self.right(),
635        )?;
636
637        SymmetricHashJoinExec::try_new(
638            Arc::new(new_left),
639            Arc::new(new_right),
640            new_on,
641            new_filter,
642            self.join_type(),
643            self.null_equality(),
644            self.right().output_ordering().cloned(),
645            self.left().output_ordering().cloned(),
646            self.partition_mode(),
647        )
648        .map(|e| Some(Arc::new(e) as _))
649    }
650}
651
/// A stream that produces the join output as [RecordBatch]es, driven by a
/// state machine over both inputs (the initial state is
/// `SHJStreamState::PullRight`, i.e. the right side is pulled first).
struct SymmetricHashJoinStream<T> {
    /// Left input stream
    left_stream: SendableRecordBatchStream,
    /// Right input stream
    right_stream: SendableRecordBatchStream,
    /// Output schema of the join
    schema: Arc<Schema>,
    /// Join filter (non-equijoin conditions), if any
    filter: Option<JoinFilter>,
    /// Type of the join
    join_type: JoinType,
    /// Left hash joiner
    left: OneSideHashJoiner,
    /// Right hash joiner
    right: OneSideHashJoiner,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Expression graph for range pruning.
    graph: Option<ExprIntervalGraph>,
    /// Left globally sorted filter expr
    left_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Right globally sorted filter expr
    right_sorted_filter_expr: Option<SortedFilterExpr>,
    /// Random state used for hashing initialization
    random_state: RandomState,
    /// Defines the null equality for the join.
    null_equality: NullEquality,
    /// Metrics
    metrics: StreamJoinMetrics,
    /// Memory reservation
    reservation: SharedMemoryReservation,
    /// State machine for input execution
    state: SHJStreamState,
    /// Transforms the output batch before returning.
    batch_transformer: T,
}
688
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
    for SymmetricHashJoinStream<T>
{
    /// Returns the output schema of the join stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
696
impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> {
    type Item = Result<RecordBatch>;

    // Delegates to `poll_next_impl` (defined elsewhere in this file).
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.poll_next_impl(cx)
    }
}
707
708/// Determine the pruning length for `buffer`.
709///
710/// This function evaluates the build side filter expression, converts the
711/// result into an array and determines the pruning length by performing a
712/// binary search on the array.
713///
714/// # Arguments
715///
716/// * `buffer`: The record batch to be pruned.
717/// * `build_side_filter_expr`: The filter expression on the build side used
718///   to determine the pruning length.
719///
720/// # Returns
721///
722/// A [Result] object that contains the pruning length. The function will return
723/// an error if
724/// - there is an issue evaluating the build side filter expression;
725/// - there is an issue converting the build side filter expression into an array
726fn determine_prune_length(
727    buffer: &RecordBatch,
728    build_side_filter_expr: &SortedFilterExpr,
729) -> Result<usize> {
730    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
731    let interval = build_side_filter_expr.interval();
732    // Evaluate the build side filter expression and convert it into an array
733    let batch_arr = origin_sorted_expr
734        .expr
735        .evaluate(buffer)?
736        .into_array(buffer.num_rows())?;
737
738    // Get the lower or upper interval based on the sort direction
739    let target = if origin_sorted_expr.options.descending {
740        interval.upper().clone()
741    } else {
742        interval.lower().clone()
743    };
744
745    // Perform binary search on the array to determine the length of the record batch to be pruned
746    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
747}
748
749/// This method determines if the result of the join should be produced in the final step or not.
750///
751/// # Arguments
752///
753/// * `build_side` - Enum indicating the side of the join used as the build side.
754/// * `join_type` - Enum indicating the type of join to be performed.
755///
756/// # Returns
757///
758/// A boolean indicating whether the result of the join should be produced in the final step or not.
759/// The result will be true if the build side is JoinSide::Left and the join type is one of
760/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi.
761/// If the build side is JoinSide::Right, the result will be true if the join type
762/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi.
763fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
764    if build_side == JoinSide::Left {
765        matches!(
766            join_type,
767            JoinType::Left
768                | JoinType::LeftAnti
769                | JoinType::Full
770                | JoinType::LeftSemi
771                | JoinType::LeftMark
772        )
773    } else {
774        matches!(
775            join_type,
776            JoinType::Right
777                | JoinType::RightAnti
778                | JoinType::Full
779                | JoinType::RightSemi
780                | JoinType::RightMark
781        )
782    }
783}
784
785/// Calculate indices by join type.
786///
787/// This method returns a tuple of two arrays: build and probe indices.
788/// The length of both arrays will be the same.
789///
790/// # Arguments
791///
792/// * `build_side`: Join side which defines the build side.
793/// * `prune_length`: Length of the prune data.
794/// * `visited_rows`: Hash set of visited rows of the build side.
795/// * `deleted_offset`: Deleted offset of the build side.
796/// * `join_type`: The type of join to be performed.
797///
798/// # Returns
799///
800/// A tuple of two arrays of primitive types representing the build and probe indices.
801fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
802    build_side: JoinSide,
803    prune_length: usize,
804    visited_rows: &HashSet<usize>,
805    deleted_offset: usize,
806    join_type: JoinType,
807) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
808where
809    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
810{
811    // Store the result in a tuple
812    let result = match (build_side, join_type) {
813        // For a mark join we “mark” each build‐side row with a dummy 0 in the probe‐side index
814        // if it ever matched. For example, if
815        //
816        // prune_length = 5
817        // deleted_offset = 0
818        // visited_rows = {1, 3}
819        //
820        // then we produce:
821        //
822        // build_indices = [0, 1, 2, 3, 4]
823        // probe_indices = [None, Some(0), None, Some(0), None]
824        //
825        // Example: for each build row i in [0..5):
826        //   – We always output its own index i in `build_indices`
827        //   – We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None`
828        (JoinSide::Left, JoinType::LeftMark) => {
829            let build_indices = (0..prune_length)
830                .map(L::Native::from_usize)
831                .collect::<PrimitiveArray<L>>();
832            let probe_indices = (0..prune_length)
833                .map(|idx| {
834                    // For mark join we output a dummy index 0 to indicate the row had a match
835                    visited_rows
836                        .contains(&(idx + deleted_offset))
837                        .then_some(R::Native::from_usize(0).unwrap())
838                })
839                .collect();
840            (build_indices, probe_indices)
841        }
842        (JoinSide::Right, JoinType::RightMark) => {
843            let build_indices = (0..prune_length)
844                .map(L::Native::from_usize)
845                .collect::<PrimitiveArray<L>>();
846            let probe_indices = (0..prune_length)
847                .map(|idx| {
848                    // For mark join we output a dummy index 0 to indicate the row had a match
849                    visited_rows
850                        .contains(&(idx + deleted_offset))
851                        .then_some(R::Native::from_usize(0).unwrap())
852                })
853                .collect();
854            (build_indices, probe_indices)
855        }
856        // In the case of `Left` or `Right` join, or `Full` join, get the anti indices
857        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
858        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
859        | (_, JoinType::Full) => {
860            let build_unmatched_indices =
861                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
862            let mut builder =
863                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
864            builder.append_nulls(build_unmatched_indices.len());
865            let probe_indices = builder.finish();
866            (build_unmatched_indices, probe_indices)
867        }
868        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
869        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
870            let build_unmatched_indices =
871                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
872            let mut builder =
873                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
874            builder.append_nulls(build_unmatched_indices.len());
875            let probe_indices = builder.finish();
876            (build_unmatched_indices, probe_indices)
877        }
878        // The case of other join types is not considered
879        _ => unreachable!(),
880    };
881    Ok(result)
882}
883
884/// This function produces unmatched record results based on the build side,
885/// join type and other parameters.
886///
887/// The method uses first `prune_length` rows from the build side input buffer
888/// to produce results.
889///
890/// # Arguments
891///
892/// * `output_schema` - The schema of the final output record batch.
893/// * `prune_length` - The length of the determined prune length.
894/// * `probe_schema` - The schema of the probe [RecordBatch].
895/// * `join_type` - The type of join to be performed.
896/// * `column_indices` - Indices of columns that are being joined.
897///
898/// # Returns
899///
900/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
901pub(crate) fn build_side_determined_results(
902    build_hash_joiner: &OneSideHashJoiner,
903    output_schema: &SchemaRef,
904    prune_length: usize,
905    probe_schema: SchemaRef,
906    join_type: JoinType,
907    column_indices: &[ColumnIndex],
908) -> Result<Option<RecordBatch>> {
909    // Check if we need to produce a result in the final output:
910    if prune_length > 0
911        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
912    {
913        // Calculate the indices for build and probe sides based on join type and build side:
914        let (build_indices, probe_indices) = calculate_indices_by_join_type(
915            build_hash_joiner.build_side,
916            prune_length,
917            &build_hash_joiner.visited_rows,
918            build_hash_joiner.deleted_offset,
919            join_type,
920        )?;
921
922        // Create an empty probe record batch:
923        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
924        // Build the final result from the indices of build and probe sides:
925        build_batch_from_indices(
926            output_schema.as_ref(),
927            &build_hash_joiner.input_buffer,
928            &empty_probe_batch,
929            &build_indices,
930            &probe_indices,
931            column_indices,
932            build_hash_joiner.build_side,
933            join_type,
934        )
935        .map(|batch| (batch.num_rows() > 0).then_some(batch))
936    } else {
937        // If we don't need to produce a result, return None
938        Ok(None)
939    }
940}
941
942/// This method performs a join between the build side input buffer and the probe side batch.
943///
944/// # Arguments
945///
946/// * `build_hash_joiner` - Build side hash joiner
947/// * `probe_hash_joiner` - Probe side hash joiner
948/// * `schema` - A reference to the schema of the output record batch.
949/// * `join_type` - The type of join to be performed.
950/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
951/// * `filter` - An optional filter on the join condition.
952/// * `probe_batch` - The second record batch to be joined.
953/// * `column_indices` - An array of columns to be selected for the result of the join.
954/// * `random_state` - The random state for the join.
955/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
956///
957/// # Returns
958///
959/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
960/// If the join type is one of the above four, the function will return [None].
#[expect(clippy::too_many_arguments)]
pub(crate) fn join_with_probe_batch(
    build_hash_joiner: &mut OneSideHashJoiner,
    probe_hash_joiner: &mut OneSideHashJoiner,
    schema: &SchemaRef,
    join_type: JoinType,
    filter: Option<&JoinFilter>,
    probe_batch: &RecordBatch,
    column_indices: &[ColumnIndex],
    random_state: &RandomState,
    null_equality: NullEquality,
) -> Result<Option<RecordBatch>> {
    // Nothing can match if either side has no rows:
    if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
        return Ok(None);
    }
    // Hash-lookup candidate (build, probe) index pairs between the buffered
    // build side and the incoming probe batch; key-collision resolution
    // happens inside the lookup:
    let (build_indices, probe_indices) = lookup_join_hashmap(
        &build_hash_joiner.hashmap,
        &build_hash_joiner.input_buffer,
        probe_batch,
        &build_hash_joiner.on,
        &probe_hash_joiner.on,
        random_state,
        null_equality,
        &mut build_hash_joiner.hashes_buffer,
        Some(build_hash_joiner.deleted_offset),
    )?;

    // If a join filter is present, keep only the pairs that satisfy it:
    let (build_indices, probe_indices) = if let Some(filter) = filter {
        apply_join_filter_to_indices(
            &build_hash_joiner.input_buffer,
            probe_batch,
            build_indices,
            probe_indices,
            filter,
            build_hash_joiner.build_side,
            None,
            join_type,
        )?
    } else {
        (build_indices, probe_indices)
    };

    // Record which build-side rows matched if this join type emits
    // build-side results during finalization (e.g. outer/anti/semi/mark):
    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
        record_visited_indices(
            &mut build_hash_joiner.visited_rows,
            build_hash_joiner.deleted_offset,
            &build_indices,
        );
    }
    // Likewise for the probe side, checked with the opposite join side:
    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
        record_visited_indices(
            &mut probe_hash_joiner.visited_rows,
            probe_hash_joiner.offset,
            &probe_indices,
        );
    }
    // Semi/anti/mark joins defer all of their output to the final step,
    // so no batch is produced here:
    if matches!(
        join_type,
        JoinType::LeftAnti
            | JoinType::RightAnti
            | JoinType::LeftSemi
            | JoinType::LeftMark
            | JoinType::RightSemi
            | JoinType::RightMark
    ) {
        Ok(None)
    } else {
        // Materialize the joined rows; suppress empty batches:
        build_batch_from_indices(
            schema,
            &build_hash_joiner.input_buffer,
            probe_batch,
            &build_indices,
            &probe_indices,
            column_indices,
            build_hash_joiner.build_side,
            join_type,
        )
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
    }
}
1041
1042/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential
1043/// hash collisions.
1044///
1045/// # Arguments
1046///
1047/// * `build_hashmap` - hashmap collected from build side data.
1048/// * `build_batch` - Build side record batch.
1049/// * `probe_batch` - Probe side record batch.
1050/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
1051/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
1052/// * `random_state` - The random state for the join.
1053/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining.
1054/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
1055/// * `deleted_offset` - deleted offset for build side data.
1056///
1057/// # Returns
1058///
1059/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side,
1060/// matched by join key columns.
1061#[expect(clippy::too_many_arguments)]
1062fn lookup_join_hashmap(
1063    build_hashmap: &PruningJoinHashMap,
1064    build_batch: &RecordBatch,
1065    probe_batch: &RecordBatch,
1066    build_on: &[PhysicalExprRef],
1067    probe_on: &[PhysicalExprRef],
1068    random_state: &RandomState,
1069    null_equality: NullEquality,
1070    hashes_buffer: &mut Vec<u64>,
1071    deleted_offset: Option<usize>,
1072) -> Result<(UInt64Array, UInt32Array)> {
1073    let keys_values = evaluate_expressions_to_arrays(probe_on, probe_batch)?;
1074    let build_join_values = evaluate_expressions_to_arrays(build_on, build_batch)?;
1075
1076    hashes_buffer.clear();
1077    hashes_buffer.resize(probe_batch.num_rows(), 0);
1078    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
1079
1080    // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm
1081    // will return build indices for each probe row in a reverse order as such:
1082    // Build Indices: [5, 4, 3]
1083    // Probe Indices: [1, 1, 1]
1084    //
1085    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
1086    // Let's consider probe rows [0,1] as an example:
1087    //
1088    // When the probe iteration sequence is reversed, the following pairings can be derived:
1089    //
1090    // For probe row 1:
1091    //     (5, 1)
1092    //     (4, 1)
1093    //     (3, 1)
1094    //
1095    // For probe row 0:
1096    //     (5, 0)
1097    //     (4, 0)
1098    //     (3, 0)
1099    //
1100    // After reversing both sets of indices, we obtain reversed indices:
1101    //
1102    //     (3,0)
1103    //     (4,0)
1104    //     (5,0)
1105    //     (3,1)
1106    //     (4,1)
1107    //     (5,1)
1108    //
1109    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
1110    let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices(
1111        Box::new(hash_values.iter().enumerate().rev()),
1112        deleted_offset,
1113    );
1114
1115    matched_probe.reverse();
1116    matched_build.reverse();
1117
1118    let build_indices: UInt64Array = matched_build.into();
1119    let probe_indices: UInt32Array = matched_probe.into();
1120
1121    let (build_indices, probe_indices) = equal_rows_arr(
1122        &build_indices,
1123        &probe_indices,
1124        &build_join_values,
1125        &keys_values,
1126        null_equality,
1127    )?;
1128
1129    Ok((build_indices, probe_indices))
1130}
1131
/// Per-side state for the symmetric hash join: buffers this side's input,
/// maintains a pruning hash map over its join keys, and tracks which
/// buffered rows have matched so far.
pub struct OneSideHashJoiner {
    /// Which side of the join this joiner represents when acting as the build side
    build_side: JoinSide,
    /// All batches received on this side so far, concatenated (minus pruned rows)
    pub input_buffer: RecordBatch,
    /// Join key expressions evaluated against this side's batches
    pub(crate) on: Vec<PhysicalExprRef>,
    /// Hash map from join-key hashes to buffered row indices (supports pruning)
    pub(crate) hashmap: PruningJoinHashMap,
    /// Reusable scratch buffer for hash computation
    pub(crate) hashes_buffer: Vec<u64>,
    /// Indices of buffered rows that have matched at least once
    pub(crate) visited_rows: HashSet<usize>,
    /// Row offset used when inserting new batches into the hash map
    /// (NOTE(review): presumably the running count of rows received — confirm at call sites)
    pub(crate) offset: usize,
    /// Number of rows pruned from the front of `input_buffer` so far
    pub(crate) deleted_offset: usize,
}
1150
1151impl OneSideHashJoiner {
1152    pub fn size(&self) -> usize {
1153        let mut size = 0;
1154        size += size_of_val(self);
1155        size += size_of_val(&self.build_side);
1156        size += self.input_buffer.get_array_memory_size();
1157        size += size_of_val(&self.on);
1158        size += self.hashmap.size();
1159        size += self.hashes_buffer.capacity() * size_of::<u64>();
1160        size += self.visited_rows.capacity() * size_of::<usize>();
1161        size += size_of_val(&self.offset);
1162        size += size_of_val(&self.deleted_offset);
1163        size
1164    }
1165    pub fn new(
1166        build_side: JoinSide,
1167        on: Vec<PhysicalExprRef>,
1168        schema: SchemaRef,
1169    ) -> Self {
1170        Self {
1171            build_side,
1172            input_buffer: RecordBatch::new_empty(schema),
1173            on,
1174            hashmap: PruningJoinHashMap::with_capacity(0),
1175            hashes_buffer: vec![],
1176            visited_rows: HashSet::new(),
1177            offset: 0,
1178            deleted_offset: 0,
1179        }
1180    }
1181
1182    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
1183    ///
1184    /// # Arguments
1185    ///
1186    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
1187    /// * `random_state` - The random state used to hash values
1188    ///
1189    /// # Returns
1190    ///
1191    /// Returns a [Result] encapsulating any intermediate errors.
1192    pub(crate) fn update_internal_state(
1193        &mut self,
1194        batch: &RecordBatch,
1195        random_state: &RandomState,
1196    ) -> Result<()> {
1197        // Merge the incoming batch with the existing input buffer:
1198        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?;
1199        // Resize the hashes buffer to the number of rows in the incoming batch:
1200        self.hashes_buffer.resize(batch.num_rows(), 0);
1201        // Get allocation_info before adding the item
1202        // Update the hashmap with the join key values and hashes of the incoming batch:
1203        update_hash(
1204            &self.on,
1205            batch,
1206            &mut self.hashmap,
1207            self.offset,
1208            random_state,
1209            &mut self.hashes_buffer,
1210            self.deleted_offset,
1211            false,
1212        )?;
1213        Ok(())
1214    }
1215
1216    /// Calculate prune length.
1217    ///
1218    /// # Arguments
1219    ///
1220    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression..
1221    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
1222    /// * `graph` - A mutable reference to the physical expression graph.
1223    ///
1224    /// # Returns
1225    ///
1226    /// A Result object that contains the pruning length.
1227    pub(crate) fn calculate_prune_length_with_probe_batch(
1228        &mut self,
1229        build_side_sorted_filter_expr: &mut SortedFilterExpr,
1230        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
1231        graph: &mut ExprIntervalGraph,
1232    ) -> Result<usize> {
1233        // Return early if the input buffer is empty:
1234        if self.input_buffer.num_rows() == 0 {
1235            return Ok(0);
1236        }
1237        // Process the build and probe side sorted filter expressions if both are present:
1238        // Collect the sorted filter expressions into a vector of (node_index, interval) tuples:
1239        let mut filter_intervals = vec![];
1240        for expr in [
1241            &build_side_sorted_filter_expr,
1242            &probe_side_sorted_filter_expr,
1243        ] {
1244            filter_intervals.push((expr.node_index(), expr.interval().clone()))
1245        }
1246        // Update the physical expression graph using the join filter intervals:
1247        graph.update_ranges(&mut filter_intervals, Interval::TRUE)?;
1248        // Extract the new join filter interval for the build side:
1249        let calculated_build_side_interval = filter_intervals.remove(0).1;
1250        // If the intervals have not changed, return early without pruning:
1251        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
1252            return Ok(0);
1253        }
1254        // Update the build side interval and determine the pruning length:
1255        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);
1256
1257        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
1258    }
1259
1260    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
1261        // Prune the hash values:
1262        self.hashmap.prune_hash_values(
1263            prune_length,
1264            self.deleted_offset as u64,
1265            HASHMAP_SHRINK_SCALE_FACTOR,
1266        );
1267        // Remove pruned rows from the visited rows set:
1268        for row in self.deleted_offset..(self.deleted_offset + prune_length) {
1269            self.visited_rows.remove(&row);
1270        }
1271        // Update the input buffer after pruning:
1272        self.input_buffer = self
1273            .input_buffer
1274            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
1275        // Increment the deleted offset:
1276        self.deleted_offset += prune_length;
1277        Ok(())
1278    }
1279}
1280
1281/// `SymmetricHashJoinStream` manages incremental join operations between two
1282/// streams. Unlike traditional join approaches that need to scan one side of
1283/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
1284/// more dynamic join operations by working with streams as they emit data. This
1285/// approach allows for more efficient processing, particularly in scenarios
1286/// where waiting for complete data materialization is not feasible or optimal.
1287/// The trait provides a framework for handling various states of such a join
1288/// process, ensuring that join logic is efficiently executed as data becomes
1289/// available from either stream.
1290///
1291/// This implementation performs eager joins of data from two different asynchronous
1292/// streams, typically referred to as left and right streams. The implementation
1293/// provides a comprehensive set of methods to control and execute the join
1294/// process, leveraging the states defined in `SHJStreamState`. Methods are
1295/// primarily focused on asynchronously fetching data batches from each stream,
1296/// processing them, and managing transitions between various states of the join.
1297///
1298/// This implementations use a state machine approach to navigate different
1299/// stages of the join operation, handling data from both streams and determining
1300/// when the join completes.
1301///
1302/// State Transitions:
1303/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
1304///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
1305///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
1306///       processing the batch.
1307///     - On error (`Some(Err(e))`), the error is returned, and the state remains
1308///       unchanged.
1309///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
1310///       to proceed with the join process.
1311/// - From `PullRight` to `PullLeft` or `RightExhausted`:
1312///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
1313///     - If a batch is available, state changes to `PullLeft` for processing.
1314///     - On error, the error is returned without changing the state.
1315///     - If right stream is exhausted (`None`), state transitions to `RightExhausted`,
1316///       with a `Continue` result.
1317/// - Handling `RightExhausted` and `LeftExhausted`:
1318///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
1319///     when streams are exhausted:
1320///     - They attempt to continue processing with the other stream.
1321///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
1322/// - Transition to `BothExhausted { final_result: true }`:
1323///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
1324///     exhausted, indicating completion of processing and availability of final results.
1325impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
1326    /// Implements the main polling logic for the join stream.
1327    ///
1328    /// This method continuously checks the state of the join stream and
1329    /// acts accordingly by delegating the handling to appropriate sub-methods
1330    /// depending on the current state.
1331    ///
1332    /// # Arguments
1333    ///
1334    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
1335    ///
1336    /// # Returns
1337    ///
1338    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
    fn poll_next_impl(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            match self.batch_transformer.next() {
                None => {
                    // No buffered output: advance the state machine by
                    // pulling from whichever stream the current state selects.
                    let result = match self.state() {
                        SHJStreamState::PullRight => {
                            ready!(self.fetch_next_from_right_stream(cx))
                        }
                        SHJStreamState::PullLeft => {
                            ready!(self.fetch_next_from_left_stream(cx))
                        }
                        SHJStreamState::RightExhausted => {
                            ready!(self.handle_right_stream_end(cx))
                        }
                        SHJStreamState::LeftExhausted => {
                            ready!(self.handle_left_stream_end(cx))
                        }
                        SHJStreamState::BothExhausted {
                            final_result: false,
                        } => self.prepare_for_final_results_after_exhaustion(),
                        // Final results already emitted: the stream is done.
                        SHJStreamState::BothExhausted { final_result: true } => {
                            return Poll::Ready(None);
                        }
                    };

                    match result? {
                        StatefulStreamResult::Ready(None) => {
                            return Poll::Ready(None);
                        }
                        StatefulStreamResult::Ready(Some(batch)) => {
                            // Hand the produced batch to the transformer; it
                            // is emitted on a subsequent loop iteration.
                            self.batch_transformer.set_batch(batch);
                        }
                        // `Continue`: loop again to make further progress.
                        _ => {}
                    }
                }
                Some((batch, _)) => {
                    // Emit the next output batch, updating baseline metrics.
                    return self
                        .metrics
                        .baseline_metrics
                        .record_poll(Poll::Ready(Some(Ok(batch))));
                }
            }
        }
    }
1386    /// Asynchronously pulls the next batch from the right stream.
1387    ///
1388    /// This default implementation checks for the next value in the right stream.
1389    /// If a batch is found, the state is switched to `PullLeft`, and the batch handling
1390    /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`.
1391    ///
1392    /// # Returns
1393    ///
1394    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1395    fn fetch_next_from_right_stream(
1396        &mut self,
1397        cx: &mut Context<'_>,
1398    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1399        match ready!(self.right_stream().poll_next_unpin(cx)) {
1400            Some(Ok(batch)) => {
1401                if batch.num_rows() == 0 {
1402                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1403                }
1404                self.set_state(SHJStreamState::PullLeft);
1405                Poll::Ready(self.process_batch_from_right(&batch))
1406            }
1407            Some(Err(e)) => Poll::Ready(Err(e)),
1408            None => {
1409                self.set_state(SHJStreamState::RightExhausted);
1410                Poll::Ready(Ok(StatefulStreamResult::Continue))
1411            }
1412        }
1413    }
1414
1415    /// Asynchronously pulls the next batch from the left stream.
1416    ///
1417    /// This default implementation checks for the next value in the left stream.
1418    /// If a batch is found, the state is switched to `PullRight`, and the batch handling
1419    /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`.
1420    ///
1421    /// # Returns
1422    ///
1423    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1424    fn fetch_next_from_left_stream(
1425        &mut self,
1426        cx: &mut Context<'_>,
1427    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1428        match ready!(self.left_stream().poll_next_unpin(cx)) {
1429            Some(Ok(batch)) => {
1430                if batch.num_rows() == 0 {
1431                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1432                }
1433                self.set_state(SHJStreamState::PullRight);
1434                Poll::Ready(self.process_batch_from_left(&batch))
1435            }
1436            Some(Err(e)) => Poll::Ready(Err(e)),
1437            None => {
1438                self.set_state(SHJStreamState::LeftExhausted);
1439                Poll::Ready(Ok(StatefulStreamResult::Continue))
1440            }
1441        }
1442    }
1443
1444    /// Asynchronously handles the scenario when the right stream is exhausted.
1445    ///
1446    /// In this default implementation, when the right stream is exhausted, it attempts
1447    /// to pull from the left stream. If a batch is found in the left stream, it delegates
1448    /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set
1449    /// to indicate both streams are exhausted without final results yet.
1450    ///
1451    /// # Returns
1452    ///
1453    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1454    fn handle_right_stream_end(
1455        &mut self,
1456        cx: &mut Context<'_>,
1457    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1458        match ready!(self.left_stream().poll_next_unpin(cx)) {
1459            Some(Ok(batch)) => {
1460                if batch.num_rows() == 0 {
1461                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1462                }
1463                Poll::Ready(self.process_batch_after_right_end(&batch))
1464            }
1465            Some(Err(e)) => Poll::Ready(Err(e)),
1466            None => {
1467                self.set_state(SHJStreamState::BothExhausted {
1468                    final_result: false,
1469                });
1470                Poll::Ready(Ok(StatefulStreamResult::Continue))
1471            }
1472        }
1473    }
1474
1475    /// Asynchronously handles the scenario when the left stream is exhausted.
1476    ///
1477    /// When the left stream is exhausted, this default
1478    /// implementation tries to pull from the right stream and delegates the batch
1479    /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state
1480    /// is updated to indicate so.
1481    ///
1482    /// # Returns
1483    ///
1484    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1485    fn handle_left_stream_end(
1486        &mut self,
1487        cx: &mut Context<'_>,
1488    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1489        match ready!(self.right_stream().poll_next_unpin(cx)) {
1490            Some(Ok(batch)) => {
1491                if batch.num_rows() == 0 {
1492                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1493                }
1494                Poll::Ready(self.process_batch_after_left_end(&batch))
1495            }
1496            Some(Err(e)) => Poll::Ready(Err(e)),
1497            None => {
1498                self.set_state(SHJStreamState::BothExhausted {
1499                    final_result: false,
1500                });
1501                Poll::Ready(Ok(StatefulStreamResult::Continue))
1502            }
1503        }
1504    }
1505
    /// Handles the state when both streams are exhausted and final results are yet to be produced.
    ///
    /// This default implementation switches the state to indicate both streams are
    /// exhausted with final results and then invokes the handling for this specific
    /// scenario via `process_batches_before_finalization`.
    ///
    /// # Returns
    ///
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
    fn prepare_for_final_results_after_exhaustion(
        &mut self,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        // Flip to the terminal state first so finalization runs only once,
        // then flush any buffered build-side results.
        self.set_state(SHJStreamState::BothExhausted { final_result: true });
        self.process_batches_before_finalization()
    }
1521
1522    fn process_batch_from_right(
1523        &mut self,
1524        batch: &RecordBatch,
1525    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1526        self.perform_join_for_given_side(batch, JoinSide::Right)
1527            .map(|maybe_batch| {
1528                if maybe_batch.is_some() {
1529                    StatefulStreamResult::Ready(maybe_batch)
1530                } else {
1531                    StatefulStreamResult::Continue
1532                }
1533            })
1534    }
1535
1536    fn process_batch_from_left(
1537        &mut self,
1538        batch: &RecordBatch,
1539    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1540        self.perform_join_for_given_side(batch, JoinSide::Left)
1541            .map(|maybe_batch| {
1542                if maybe_batch.is_some() {
1543                    StatefulStreamResult::Ready(maybe_batch)
1544                } else {
1545                    StatefulStreamResult::Continue
1546                }
1547            })
1548    }
1549
    /// Handles a right-side batch that arrives after the left stream has
    /// ended; delegates to the regular right-batch processing path.
    fn process_batch_after_left_end(
        &mut self,
        right_batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_right(right_batch)
    }
1556
    /// Handles a left-side batch that arrives after the right stream has
    /// ended; delegates to the regular left-batch processing path.
    fn process_batch_after_right_end(
        &mut self,
        left_batch: &RecordBatch,
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
        self.process_batch_from_left(left_batch)
    }
1563
1564    fn process_batches_before_finalization(
1565        &mut self,
1566    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1567        // Get the left side results:
1568        let left_result = build_side_determined_results(
1569            &self.left,
1570            &self.schema,
1571            self.left.input_buffer.num_rows(),
1572            self.right.input_buffer.schema(),
1573            self.join_type,
1574            &self.column_indices,
1575        )?;
1576        // Get the right side results:
1577        let right_result = build_side_determined_results(
1578            &self.right,
1579            &self.schema,
1580            self.right.input_buffer.num_rows(),
1581            self.left.input_buffer.schema(),
1582            self.join_type,
1583            &self.column_indices,
1584        )?;
1585
1586        // Combine the left and right results:
1587        let result = combine_two_batches(&self.schema, left_result, right_result)?;
1588
1589        // Return the result:
1590        if result.is_some() {
1591            return Ok(StatefulStreamResult::Ready(result));
1592        }
1593        Ok(StatefulStreamResult::Continue)
1594    }
1595
    /// Returns a mutable reference to the right input stream.
    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.right_stream
    }
1599
    /// Returns a mutable reference to the left input stream.
    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
        &mut self.left_stream
    }
1603
    /// Transitions the stream's state machine to `state`.
    fn set_state(&mut self, state: SHJStreamState) {
        self.state = state;
    }
1607
    /// Returns a clone of the stream's current state.
    fn state(&mut self) -> SHJStreamState {
        self.state.clone()
    }
1611
1612    fn size(&self) -> usize {
1613        let mut size = 0;
1614        size += size_of_val(&self.schema);
1615        size += size_of_val(&self.filter);
1616        size += size_of_val(&self.join_type);
1617        size += self.left.size();
1618        size += self.right.size();
1619        size += size_of_val(&self.column_indices);
1620        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
1621        size += size_of_val(&self.left_sorted_filter_expr);
1622        size += size_of_val(&self.right_sorted_filter_expr);
1623        size += size_of_val(&self.random_state);
1624        size += size_of_val(&self.null_equality);
1625        size += size_of_val(&self.metrics);
1626        size
1627    }
1628
    /// Performs a join operation for the specified `probe_side` (either left or right).
    /// This function:
    /// 1. Determines which side is the probe and which is the build side.
    /// 2. Updates metrics based on the batch that was polled.
    /// 3. Executes the join with the given `probe_batch`.
    /// 4. Optionally computes anti-join results if all conditions are met.
    /// 5. Combines the results and returns a combined batch or `None` if no batch was produced.
    ///
    /// Note: despite the build/probe naming, `update_internal_state` below is
    /// called on the *probe* joiner — in a symmetric join each side buffers
    /// its own input so it can later act as the build side for the other.
    fn perform_join_for_given_side(
        &mut self,
        probe_batch: &RecordBatch,
        probe_side: JoinSide,
    ) -> Result<Option<RecordBatch>> {
        // Select the per-side state depending on which side produced the
        // batch; the opposite side plays the role of the build side.
        let (
            probe_hash_joiner,
            build_hash_joiner,
            probe_side_sorted_filter_expr,
            build_side_sorted_filter_expr,
            probe_side_metrics,
        ) = if probe_side.eq(&JoinSide::Left) {
            (
                &mut self.left,
                &mut self.right,
                &mut self.left_sorted_filter_expr,
                &mut self.right_sorted_filter_expr,
                &mut self.metrics.left,
            )
        } else {
            (
                &mut self.right,
                &mut self.left,
                &mut self.right_sorted_filter_expr,
                &mut self.left_sorted_filter_expr,
                &mut self.metrics.right,
            )
        };
        // Update the metrics for the stream that was polled:
        probe_side_metrics.input_batches.add(1);
        probe_side_metrics.input_rows.add(probe_batch.num_rows());
        // Update the internal state of the hash joiner for the build side:
        probe_hash_joiner.update_internal_state(probe_batch, &self.random_state)?;
        // Join the two sides:
        let equal_result = join_with_probe_batch(
            build_hash_joiner,
            probe_hash_joiner,
            &self.schema,
            self.join_type,
            self.filter.as_ref(),
            probe_batch,
            &self.column_indices,
            &self.random_state,
            self.null_equality,
        )?;
        // Increment the offset for the probe hash joiner:
        probe_hash_joiner.offset += probe_batch.num_rows();

        // Range pruning is only possible when both sides have sorted filter
        // expressions and a constraint graph was built; otherwise skip it.
        let anti_result = if let (
            Some(build_side_sorted_filter_expr),
            Some(probe_side_sorted_filter_expr),
            Some(graph),
        ) = (
            build_side_sorted_filter_expr.as_mut(),
            probe_side_sorted_filter_expr.as_mut(),
            self.graph.as_mut(),
        ) {
            // Calculate filter intervals:
            calculate_filter_expr_intervals(
                &build_hash_joiner.input_buffer,
                build_side_sorted_filter_expr,
                probe_batch,
                probe_side_sorted_filter_expr,
            )?;
            // Determine how many buffered build-side rows can never match
            // future probe rows and may therefore be emitted/discarded:
            let prune_length = build_hash_joiner
                .calculate_prune_length_with_probe_batch(
                    build_side_sorted_filter_expr,
                    probe_side_sorted_filter_expr,
                    graph,
                )?;
            // Emit results for the pruned rows (e.g. unmatched rows for
            // outer-style joins) before actually dropping them:
            let result = build_side_determined_results(
                build_hash_joiner,
                &self.schema,
                prune_length,
                probe_batch.schema(),
                self.join_type,
                &self.column_indices,
            )?;
            build_hash_joiner.prune_internal_state(prune_length)?;
            result
        } else {
            None
        };

        // Combine results:
        let result = combine_two_batches(&self.schema, equal_result, anti_result)?;
        // Refresh the memory accounting after buffers grew/shrank:
        let capacity = self.size();
        self.metrics.stream_memory_usage.set(capacity);
        self.reservation.lock().try_resize(capacity)?;
        Ok(result)
    }
1727}
1728
/// Represents the various states of a symmetric hash join stream operation.
///
/// This enum is used to track the current state of streaming during a join
/// operation. It provides indicators as to which side of the join needs to be
/// pulled next or if one (or both) sides have been exhausted. This allows
/// for efficient management of resources and optimal performance during the
/// join process.
#[derive(Clone, Debug)]
pub enum SHJStreamState {
    /// Indicates that the next step should pull from the right side of the join.
    PullRight,

    /// Indicates that the next step should pull from the left side of the join.
    PullLeft,

    /// State representing that the right side of the join has been fully processed.
    RightExhausted,

    /// State representing that the left side of the join has been fully processed.
    LeftExhausted,

    /// Represents a state where both sides of the join are exhausted.
    ///
    /// The `final_result` field indicates whether the join operation has
    /// produced a final result or not.
    BothExhausted { final_result: bool },
}
1756
1757#[cfg(test)]
1758mod tests {
1759    use std::collections::HashMap;
1760    use std::sync::{LazyLock, Mutex};
1761
1762    use super::*;
1763    use crate::joins::test_utils::{
1764        build_sides_record_batches, compare_batches, complicated_filter,
1765        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1766        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1767        partitioned_sym_join_with_filter, split_record_batches,
1768    };
1769
1770    use arrow::compute::SortOptions;
1771    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1772    use datafusion_common::ScalarValue;
1773    use datafusion_execution::config::SessionConfig;
1774    use datafusion_expr::Operator;
1775    use datafusion_physical_expr::expressions::{Column, binary, col, lit};
1776    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
1777
1778    use rstest::*;
1779
    // Row count used for all generated test tables.
    const TABLE_SIZE: i32 = 30;

    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)

    // Process-wide cache of generated test tables, keyed by cardinality and
    // batch size, so identical inputs are built only once across test cases.
    static TABLE_CACHE: LazyLock<Mutex<HashMap<TableKey, TableValue>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));
1788
1789    fn get_or_create_table(
1790        cardinality: (i32, i32),
1791        batch_size: usize,
1792    ) -> Result<TableValue> {
1793        {
1794            let cache = TABLE_CACHE.lock().unwrap();
1795            if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1796                return Ok(table.clone());
1797            }
1798        }
1799
1800        // If not, create the table
1801        let (left_batch, right_batch) =
1802            build_sides_record_batches(TABLE_SIZE, cardinality)?;
1803
1804        let (left_partition, right_partition) = (
1805            split_record_batches(&left_batch, batch_size)?,
1806            split_record_batches(&right_batch, batch_size)?,
1807        );
1808
1809        // Lock the cache again and store the table
1810        let mut cache = TABLE_CACHE.lock().unwrap();
1811
1812        // Store the table in the cache
1813        cache.insert(
1814            (cardinality.0, cardinality.1, batch_size),
1815            (left_partition.clone(), right_partition.clone()),
1816        );
1817
1818        Ok((left_partition, right_partition))
1819    }
1820
1821    pub async fn experiment(
1822        left: Arc<dyn ExecutionPlan>,
1823        right: Arc<dyn ExecutionPlan>,
1824        filter: Option<JoinFilter>,
1825        join_type: JoinType,
1826        on: JoinOn,
1827        task_ctx: Arc<TaskContext>,
1828    ) -> Result<()> {
1829        let first_batches = partitioned_sym_join_with_filter(
1830            Arc::clone(&left),
1831            Arc::clone(&right),
1832            on.clone(),
1833            filter.clone(),
1834            &join_type,
1835            NullEquality::NullEqualsNothing,
1836            Arc::clone(&task_ctx),
1837        )
1838        .await?;
1839        let second_batches = partitioned_hash_join_with_filter(
1840            left,
1841            right,
1842            on,
1843            filter,
1844            &join_type,
1845            NullEquality::NullEqualsNothing,
1846            task_ctx,
1847        )
1848        .await?;
1849        compare_batches(&first_batches, &second_batches);
1850        Ok(())
1851    }
1852
    /// Joins inputs ordered on a *compound* expression (`la1 + la2`) against a
    /// simple column ordering (`ra1`), with a multi-column range filter and a
    /// computed equi-join key, across all join types and two cardinalities.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn complex_join_all_one_ascending_numeric(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(
        (4, 5),
        (12, 17),
        )]
        cardinality: (i32, i32),
    ) -> Result<()> {
        // a + b > c + 10 AND a + b < c + 100
        let task_ctx = Arc::new(TaskContext::default());

        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();

        // Left ordering is on the sum `la1 + la2`, not a plain column:
        let left_sorted = [PhysicalSortExpr {
            expr: binary(
                col("la1", left_schema)?,
                Operator::Plus,
                col("la2", left_schema)?,
                left_schema,
            )?,
            options: SortOptions::default(),
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("ra1", right_schema)?,
            options: SortOptions::default(),
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        // Equi-join key is itself computed: `lc1 + 1 = rc1`.
        let on = vec![(
            binary(
                col("lc1", left_schema)?,
                Operator::Plus,
                lit(ScalarValue::Int32(Some(1))),
                left_schema,
            )?,
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
        )];

        // Intermediate schema the filter expression is evaluated against:
        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&intermediate_schema)?;
        // Map the filter's columns back to `la1`, `la2` and `ra1`:
        let column_indices = vec![
            ColumnIndex {
                index: left_schema.index_of("la1")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: left_schema.index_of("la2")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: right_schema.index_of("ra1")?,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
1941
1942    #[rstest]
1943    #[tokio::test(flavor = "multi_thread")]
1944    async fn join_all_one_ascending_numeric(
1945        #[values(
1946            JoinType::Inner,
1947            JoinType::Left,
1948            JoinType::Right,
1949            JoinType::RightSemi,
1950            JoinType::LeftSemi,
1951            JoinType::LeftAnti,
1952            JoinType::LeftMark,
1953            JoinType::RightAnti,
1954            JoinType::RightMark,
1955            JoinType::Full
1956        )]
1957        join_type: JoinType,
1958        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1959    ) -> Result<()> {
1960        let task_ctx = Arc::new(TaskContext::default());
1961        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1962
1963        let left_schema = &left_partition[0].schema();
1964        let right_schema = &right_partition[0].schema();
1965
1966        let left_sorted = [PhysicalSortExpr {
1967            expr: col("la1", left_schema)?,
1968            options: SortOptions::default(),
1969        }]
1970        .into();
1971        let right_sorted = [PhysicalSortExpr {
1972            expr: col("ra1", right_schema)?,
1973            options: SortOptions::default(),
1974        }]
1975        .into();
1976        let (left, right) = create_memory_table(
1977            left_partition,
1978            right_partition,
1979            vec![left_sorted],
1980            vec![right_sorted],
1981        )?;
1982
1983        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
1984
1985        let intermediate_schema = Schema::new(vec![
1986            Field::new("left", DataType::Int32, true),
1987            Field::new("right", DataType::Int32, true),
1988        ]);
1989        let filter_expr = join_expr_tests_fixture_i32(
1990            case_expr,
1991            col("left", &intermediate_schema)?,
1992            col("right", &intermediate_schema)?,
1993        );
1994        let column_indices = vec![
1995            ColumnIndex {
1996                index: 0,
1997                side: JoinSide::Left,
1998            },
1999            ColumnIndex {
2000                index: 0,
2001                side: JoinSide::Right,
2002            },
2003        ];
2004        let filter =
2005            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2006
2007        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2008        Ok(())
2009    }
2010
    /// Runs the join with inputs that provide no sort information, so the
    /// range-pruning machinery cannot engage; results must still match the
    /// regular hash join for every join type and filter case.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn join_without_sort_information(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
    ) -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // No orderings are passed, so the join cannot prune its buffers.
        let (left, right) =
            create_memory_table(left_partition, right_partition, vec![], vec![])?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        // Both filter columns reference index 5 on their respective sides.
        let column_indices = vec![
            ColumnIndex {
                index: 5,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 5,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2064
2065    #[rstest]
2066    #[tokio::test(flavor = "multi_thread")]
2067    async fn join_without_filter(
2068        #[values(
2069            JoinType::Inner,
2070            JoinType::Left,
2071            JoinType::Right,
2072            JoinType::RightSemi,
2073            JoinType::LeftSemi,
2074            JoinType::LeftAnti,
2075            JoinType::LeftMark,
2076            JoinType::RightAnti,
2077            JoinType::RightMark,
2078            JoinType::Full
2079        )]
2080        join_type: JoinType,
2081    ) -> Result<()> {
2082        let task_ctx = Arc::new(TaskContext::default());
2083        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
2084        let left_schema = &left_partition[0].schema();
2085        let right_schema = &right_partition[0].schema();
2086        let (left, right) =
2087            create_memory_table(left_partition, right_partition, vec![], vec![])?;
2088
2089        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2090        experiment(left, right, None, join_type, on, task_ctx).await?;
2091        Ok(())
2092    }
2093
    /// Joins inputs sorted *descending* (nulls first) on `la1_des`/`ra1_des`
    /// for every join type and each i32 filter fixture, comparing SHJ output
    /// against the regular hash join.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn join_all_one_descending_numeric_particular(
        #[values(
            JoinType::Inner,
            JoinType::Left,
            JoinType::Right,
            JoinType::RightSemi,
            JoinType::LeftSemi,
            JoinType::LeftAnti,
            JoinType::LeftMark,
            JoinType::RightAnti,
            JoinType::RightMark,
            JoinType::Full
        )]
        join_type: JoinType,
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
    ) -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // Descending order with nulls first on both sides:
        let left_sorted = [PhysicalSortExpr {
            expr: col("la1_des", left_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("ra1_des", right_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        // Both filter columns reference index 5 on their respective sides.
        let column_indices = vec![
            ColumnIndex {
                index: 5,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 5,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2166
    /// Full join over columns sorted ascending with nulls *first*
    /// (`l_asc_null_first`/`r_asc_null_first`), with repartitioning disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn build_null_columns_first() -> Result<()> {
        let join_type = JoinType::Full;
        let case_expr = 1;
        let session_config = SessionConfig::new().with_repartition_joins(false);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;
        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // Ascending order with nulls first on both sides:
        let left_sorted = [PhysicalSortExpr {
            expr: col("l_asc_null_first", left_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("r_asc_null_first", right_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        // Both filter columns reference index 6 on their respective sides.
        let column_indices = vec![
            ColumnIndex {
                index: 6,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 6,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2226
    /// Full join over columns sorted ascending with nulls *last*
    /// (`l_asc_null_last`/`r_asc_null_last`), with repartitioning disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn build_null_columns_last() -> Result<()> {
        let join_type = JoinType::Full;
        let case_expr = 1;
        let session_config = SessionConfig::new().with_repartition_joins(false);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // Ascending order with nulls last on both sides:
        let left_sorted = [PhysicalSortExpr {
            expr: col("l_asc_null_last", left_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("r_asc_null_last", right_schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        // Both filter columns reference index 7 on their respective sides.
        let column_indices = vec![
            ColumnIndex {
                index: 7,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 7,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2288
    /// Full join over columns sorted *descending* with nulls first
    /// (`l_desc_null_first`/`r_desc_null_first`), with repartitioning
    /// disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn build_null_columns_first_descending() -> Result<()> {
        let join_type = JoinType::Full;
        let cardinality = (10, 11);
        let case_expr = 1;
        let session_config = SessionConfig::new().with_repartition_joins(false);
        let task_ctx = TaskContext::default().with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;

        let left_schema = &left_partition[0].schema();
        let right_schema = &right_partition[0].schema();
        // Descending order with nulls first on both sides:
        let left_sorted = [PhysicalSortExpr {
            expr: col("l_desc_null_first", left_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let right_sorted = [PhysicalSortExpr {
            expr: col("r_desc_null_first", right_schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();
        let (left, right) = create_memory_table(
            left_partition,
            right_partition,
            vec![left_sorted],
            vec![right_sorted],
        )?;

        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];

        let intermediate_schema = Schema::new(vec![
            Field::new("left", DataType::Int32, true),
            Field::new("right", DataType::Int32, true),
        ]);
        let filter_expr = join_expr_tests_fixture_i32(
            case_expr,
            col("left", &intermediate_schema)?,
            col("right", &intermediate_schema)?,
        );
        // Both filter columns reference index 8 on their respective sides.
        let column_indices = vec![
            ColumnIndex {
                index: 8,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 8,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
        Ok(())
    }
2351
2352    #[tokio::test(flavor = "multi_thread")]
2353    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2354        let cardinality = (3, 4);
2355        let join_type = JoinType::Full;
2356
2357        // a + b > c + 10 AND a + b < c + 100
2358        let session_config = SessionConfig::new().with_repartition_joins(false);
2359        let task_ctx = TaskContext::default().with_session_config(session_config);
2360        let task_ctx = Arc::new(task_ctx);
2361        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2362
2363        let left_schema = &left_partition[0].schema();
2364        let right_schema = &right_partition[0].schema();
2365        let left_sorted = [PhysicalSortExpr {
2366            expr: col("la1", left_schema)?,
2367            options: SortOptions::default(),
2368        }]
2369        .into();
2370        let right_sorted = [PhysicalSortExpr {
2371            expr: col("ra1", right_schema)?,
2372            options: SortOptions::default(),
2373        }]
2374        .into();
2375        let (left, right) = create_memory_table(
2376            left_partition,
2377            right_partition,
2378            vec![left_sorted],
2379            vec![right_sorted],
2380        )?;
2381
2382        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2383
2384        let intermediate_schema = Schema::new(vec![
2385            Field::new("0", DataType::Int32, true),
2386            Field::new("1", DataType::Int32, true),
2387            Field::new("2", DataType::Int32, true),
2388        ]);
2389        let filter_expr = complicated_filter(&intermediate_schema)?;
2390        let column_indices = vec![
2391            ColumnIndex {
2392                index: 0,
2393                side: JoinSide::Left,
2394            },
2395            ColumnIndex {
2396                index: 4,
2397                side: JoinSide::Left,
2398            },
2399            ColumnIndex {
2400                index: 0,
2401                side: JoinSide::Right,
2402            },
2403        ];
2404        let filter =
2405            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2406
2407        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2408        Ok(())
2409    }
2410
2411    #[tokio::test(flavor = "multi_thread")]
2412    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2413        let cardinality = (3, 4);
2414        let join_type = JoinType::Full;
2415
2416        // a + b > c + 10 AND a + b < c + 100
2417        let config = SessionConfig::new().with_repartition_joins(false);
2418        // let session_ctx = SessionContext::with_config(config);
2419        // let task_ctx = session_ctx.task_ctx();
2420        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2421        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2422        let left_schema = &left_partition[0].schema();
2423        let right_schema = &right_partition[0].schema();
2424        let left_sorted = vec![
2425            [PhysicalSortExpr {
2426                expr: col("la1", left_schema)?,
2427                options: SortOptions::default(),
2428            }]
2429            .into(),
2430            [PhysicalSortExpr {
2431                expr: col("la2", left_schema)?,
2432                options: SortOptions::default(),
2433            }]
2434            .into(),
2435        ];
2436
2437        let right_sorted = [PhysicalSortExpr {
2438            expr: col("ra1", right_schema)?,
2439            options: SortOptions::default(),
2440        }]
2441        .into();
2442
2443        let (left, right) = create_memory_table(
2444            left_partition,
2445            right_partition,
2446            left_sorted,
2447            vec![right_sorted],
2448        )?;
2449
2450        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2451
2452        let intermediate_schema = Schema::new(vec![
2453            Field::new("0", DataType::Int32, true),
2454            Field::new("1", DataType::Int32, true),
2455            Field::new("2", DataType::Int32, true),
2456        ]);
2457        let filter_expr = complicated_filter(&intermediate_schema)?;
2458        let column_indices = vec![
2459            ColumnIndex {
2460                index: 0,
2461                side: JoinSide::Left,
2462            },
2463            ColumnIndex {
2464                index: 4,
2465                side: JoinSide::Left,
2466            },
2467            ColumnIndex {
2468                index: 0,
2469                side: JoinSide::Right,
2470            },
2471        ];
2472        let filter =
2473            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2474
2475        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2476        Ok(())
2477    }
2478
2479    #[rstest]
2480    #[tokio::test(flavor = "multi_thread")]
2481    async fn testing_with_temporal_columns(
2482        #[values(
2483            JoinType::Inner,
2484            JoinType::Left,
2485            JoinType::Right,
2486            JoinType::RightSemi,
2487            JoinType::LeftSemi,
2488            JoinType::LeftAnti,
2489            JoinType::LeftMark,
2490            JoinType::RightAnti,
2491            JoinType::RightMark,
2492            JoinType::Full
2493        )]
2494        join_type: JoinType,
2495        #[values(
2496            (4, 5),
2497            (12, 17),
2498        )]
2499        cardinality: (i32, i32),
2500        #[values(0, 1, 2)] case_expr: usize,
2501    ) -> Result<()> {
2502        let session_config = SessionConfig::new().with_repartition_joins(false);
2503        let task_ctx = TaskContext::default().with_session_config(session_config);
2504        let task_ctx = Arc::new(task_ctx);
2505        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2506
2507        let left_schema = &left_partition[0].schema();
2508        let right_schema = &right_partition[0].schema();
2509        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2510        let left_sorted = [PhysicalSortExpr {
2511            expr: col("lt1", left_schema)?,
2512            options: SortOptions {
2513                descending: false,
2514                nulls_first: true,
2515            },
2516        }]
2517        .into();
2518        let right_sorted = [PhysicalSortExpr {
2519            expr: col("rt1", right_schema)?,
2520            options: SortOptions {
2521                descending: false,
2522                nulls_first: true,
2523            },
2524        }]
2525        .into();
2526        let (left, right) = create_memory_table(
2527            left_partition,
2528            right_partition,
2529            vec![left_sorted],
2530            vec![right_sorted],
2531        )?;
2532        let intermediate_schema = Schema::new(vec![
2533            Field::new(
2534                "left",
2535                DataType::Timestamp(TimeUnit::Millisecond, None),
2536                false,
2537            ),
2538            Field::new(
2539                "right",
2540                DataType::Timestamp(TimeUnit::Millisecond, None),
2541                false,
2542            ),
2543        ]);
2544        let filter_expr = join_expr_tests_fixture_temporal(
2545            case_expr,
2546            col("left", &intermediate_schema)?,
2547            col("right", &intermediate_schema)?,
2548            &intermediate_schema,
2549        )?;
2550        let column_indices = vec![
2551            ColumnIndex {
2552                index: 3,
2553                side: JoinSide::Left,
2554            },
2555            ColumnIndex {
2556                index: 3,
2557                side: JoinSide::Right,
2558            },
2559        ];
2560        let filter =
2561            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2562        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2563        Ok(())
2564    }
2565
2566    #[rstest]
2567    #[tokio::test(flavor = "multi_thread")]
2568    async fn test_with_interval_columns(
2569        #[values(
2570            JoinType::Inner,
2571            JoinType::Left,
2572            JoinType::Right,
2573            JoinType::RightSemi,
2574            JoinType::LeftSemi,
2575            JoinType::LeftAnti,
2576            JoinType::LeftMark,
2577            JoinType::RightAnti,
2578            JoinType::RightMark,
2579            JoinType::Full
2580        )]
2581        join_type: JoinType,
2582        #[values(
2583            (4, 5),
2584            (12, 17),
2585        )]
2586        cardinality: (i32, i32),
2587    ) -> Result<()> {
2588        let session_config = SessionConfig::new().with_repartition_joins(false);
2589        let task_ctx = TaskContext::default().with_session_config(session_config);
2590        let task_ctx = Arc::new(task_ctx);
2591        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2592
2593        let left_schema = &left_partition[0].schema();
2594        let right_schema = &right_partition[0].schema();
2595        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2596        let left_sorted = [PhysicalSortExpr {
2597            expr: col("li1", left_schema)?,
2598            options: SortOptions {
2599                descending: false,
2600                nulls_first: true,
2601            },
2602        }]
2603        .into();
2604        let right_sorted = [PhysicalSortExpr {
2605            expr: col("ri1", right_schema)?,
2606            options: SortOptions {
2607                descending: false,
2608                nulls_first: true,
2609            },
2610        }]
2611        .into();
2612        let (left, right) = create_memory_table(
2613            left_partition,
2614            right_partition,
2615            vec![left_sorted],
2616            vec![right_sorted],
2617        )?;
2618        let intermediate_schema = Schema::new(vec![
2619            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2620            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2621        ]);
2622        let filter_expr = join_expr_tests_fixture_temporal(
2623            0,
2624            col("left", &intermediate_schema)?,
2625            col("right", &intermediate_schema)?,
2626            &intermediate_schema,
2627        )?;
2628        let column_indices = vec![
2629            ColumnIndex {
2630                index: 9,
2631                side: JoinSide::Left,
2632            },
2633            ColumnIndex {
2634                index: 9,
2635                side: JoinSide::Right,
2636            },
2637        ];
2638        let filter =
2639            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2640        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2641
2642        Ok(())
2643    }
2644
2645    #[rstest]
2646    #[tokio::test(flavor = "multi_thread")]
2647    async fn testing_ascending_float_pruning(
2648        #[values(
2649            JoinType::Inner,
2650            JoinType::Left,
2651            JoinType::Right,
2652            JoinType::RightSemi,
2653            JoinType::LeftSemi,
2654            JoinType::LeftAnti,
2655            JoinType::LeftMark,
2656            JoinType::RightAnti,
2657            JoinType::RightMark,
2658            JoinType::Full
2659        )]
2660        join_type: JoinType,
2661        #[values(
2662            (4, 5),
2663            (12, 17),
2664        )]
2665        cardinality: (i32, i32),
2666        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2667    ) -> Result<()> {
2668        let session_config = SessionConfig::new().with_repartition_joins(false);
2669        let task_ctx = TaskContext::default().with_session_config(session_config);
2670        let task_ctx = Arc::new(task_ctx);
2671        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2672
2673        let left_schema = &left_partition[0].schema();
2674        let right_schema = &right_partition[0].schema();
2675        let left_sorted = [PhysicalSortExpr {
2676            expr: col("l_float", left_schema)?,
2677            options: SortOptions::default(),
2678        }]
2679        .into();
2680        let right_sorted = [PhysicalSortExpr {
2681            expr: col("r_float", right_schema)?,
2682            options: SortOptions::default(),
2683        }]
2684        .into();
2685        let (left, right) = create_memory_table(
2686            left_partition,
2687            right_partition,
2688            vec![left_sorted],
2689            vec![right_sorted],
2690        )?;
2691
2692        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
2693
2694        let intermediate_schema = Schema::new(vec![
2695            Field::new("left", DataType::Float64, true),
2696            Field::new("right", DataType::Float64, true),
2697        ]);
2698        let filter_expr = join_expr_tests_fixture_f64(
2699            case_expr,
2700            col("left", &intermediate_schema)?,
2701            col("right", &intermediate_schema)?,
2702        );
2703        let column_indices = vec![
2704            ColumnIndex {
2705                index: 10, // l_float
2706                side: JoinSide::Left,
2707            },
2708            ColumnIndex {
2709                index: 10, // r_float
2710                side: JoinSide::Right,
2711            },
2712        ];
2713        let filter =
2714            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
2715
2716        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2717        Ok(())
2718    }
2719}