Skip to main content

datafusion_physical_plan/joins/sort_merge_join/
exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the Sort-Merge join execution plan.
//! A Sort-Merge join plan consumes two sorted child plans and produces
//! joined output according to the given join type and other options.
21
22use std::fmt::Formatter;
23use std::sync::Arc;
24
25use super::bitwise_stream::BitwiseSortMergeJoinStream;
26use super::materializing_stream::MaterializingSortMergeJoinStream;
27use super::metrics::SortMergeJoinMetrics;
28use crate::execution_plan::{EmissionType, boundedness_from_children};
29use crate::expressions::PhysicalSortExpr;
30use crate::joins::utils::{
31    JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid,
32    estimate_join_statistics, reorder_output_after_swap,
33    symmetric_join_output_partitioning,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics};
36use crate::projection::{
37    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
38    physical_to_column_exprs, update_join_on,
39};
40use crate::spill::spill_manager::SpillManager;
41use crate::{
42    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
43    PlanProperties, SendableRecordBatchStream, Statistics, check_if_same_properties,
44};
45
46use arrow::compute::SortOptions;
47use arrow::datatypes::SchemaRef;
48use datafusion_common::{
49    JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err,
50    plan_err,
51};
52use datafusion_execution::TaskContext;
53use datafusion_execution::memory_pool::MemoryConsumer;
54use datafusion_physical_expr::equivalence::join_equivalence_properties;
55use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
56use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
57
/// Join execution plan that executes equi-join predicates on multiple partitions using the
/// Sort-Merge join algorithm and applies an optional filter after the join. Can be used to join
/// arbitrarily large inputs where one or both of the inputs don't fit in the available memory.
///
/// # Join Expressions
///
/// Equi-join predicate (e.g. `<col1> = <col2>`) expressions are represented by [`Self::on`].
///
/// Non-equality predicates, which can not be pushed down to join inputs (e.g.
/// `<col1> != <col2>`) are known as "filter expressions" and are evaluated
/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional
/// expressions.
///
/// # Sorting
///
/// Assumes that both the left and right input to the join are pre-sorted. It is not the
/// responsibility of this execution plan to sort the inputs; instead it reports the orderings
/// it needs (one sort expression per `on` key, using the matching entry of
/// [`Self::sort_options`]) through `required_input_ordering`.
///
/// # "Streamed" vs "Buffered"
///
/// The number of record batches of streamed input currently present in the memory will depend
/// on the output batch size of the execution plan. There is no spilling support for streamed input.
/// The comparisons are performed from values of join keys in streamed input with the values of
/// join keys in buffered input. One row in streamed record batch could be matched with multiple rows in
/// buffered input batches. The streamed input is managed through the states in `StreamedState`
/// and streamed input batches are represented by `StreamedBatch`.
///
/// Buffered input is buffered for all record batches having the same value of join key.
/// If memory usage exceeds the configured limit and spilling is enabled,
/// buffered batches could be spilled to disk. If spilling is disabled, the execution
/// will fail under the same conditions. Multiple record batches of buffered could currently reside
/// in memory/disk during the execution. The number of buffered batches residing in
/// memory/disk depends on the number of rows of buffered input having the same value
/// of join key as that of streamed input rows currently present in memory. Due to pre-sorted inputs,
/// the algorithm understands when it is not needed anymore, and releases the buffered batches
/// from memory/disk. The buffered input is managed through the states in `BufferedState`
/// and buffered input batches are represented by `BufferedBatch`.
///
/// Depending on the type of join, left or right input may be selected as streamed or buffered
/// respectively (see [`SortMergeJoinExec::probe_side`]). For example, in a left-outer join, the
/// left execution plan will be selected as streamed input while in a right-outer join, the right
/// execution plan will be selected as the streamed input.
///
/// Reference for the algorithm:
/// <https://en.wikipedia.org/wiki/Sort-merge_join>.
///
/// Helpful short video demonstration:
/// <https://www.youtube.com/watch?v=jiWCPJtDE2c>.
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
    /// Left sorted joining execution plan
    pub left: Arc<dyn ExecutionPlan>,
    /// Right sorted joining execution plan
    pub right: Arc<dyn ExecutionPlan>,
    /// Equi-join key pairs `(left_expr, right_expr)`; each pair also defines one
    /// sort key for the left and right inputs respectively
    pub on: JoinOn,
    /// Optional non-equi filter applied while finding matching rows
    pub filter: Option<JoinFilter>,
    /// How the join is performed
    pub join_type: JoinType,
    /// The output schema of the join, derived from both input schemas and `join_type`
    schema: SchemaRef,
    /// Execution metrics (a fresh set is created whenever the children are replaced)
    metrics: ExecutionPlanMetricsSet,
    /// Required ordering of the left input: the left `on` expressions with `sort_options`
    left_sort_exprs: LexOrdering,
    /// Required ordering of the right input: the right `on` expressions with `sort_options`
    right_sort_exprs: LexOrdering,
    /// Sort options of the join columns, one entry per `on` pair, shared by both inputs
    pub sort_options: Vec<SortOptions>,
    /// Defines whether NULL join keys compare equal to each other.
    pub null_equality: NullEquality,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
}
133
134impl SortMergeJoinExec {
135    /// Tries to create a new [SortMergeJoinExec].
136    /// The inputs are sorted using `sort_options` are applied to the columns in the `on`
137    /// # Error
138    /// This function errors when it is not possible to join the left and right sides on keys `on`.
139    pub fn try_new(
140        left: Arc<dyn ExecutionPlan>,
141        right: Arc<dyn ExecutionPlan>,
142        on: JoinOn,
143        filter: Option<JoinFilter>,
144        join_type: JoinType,
145        sort_options: Vec<SortOptions>,
146        null_equality: NullEquality,
147    ) -> Result<Self> {
148        let left_schema = left.schema();
149        let right_schema = right.schema();
150
151        check_join_is_valid(&left_schema, &right_schema, &on)?;
152        if sort_options.len() != on.len() {
153            return plan_err!(
154                "Expected number of sort options: {}, actual: {}",
155                on.len(),
156                sort_options.len()
157            );
158        }
159
160        let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
161            .iter()
162            .zip(sort_options.iter())
163            .map(|((l, r), sort_op)| {
164                let left = PhysicalSortExpr {
165                    expr: Arc::clone(l),
166                    options: *sort_op,
167                };
168                let right = PhysicalSortExpr {
169                    expr: Arc::clone(r),
170                    options: *sort_op,
171                };
172                (left, right)
173            })
174            .unzip();
175        let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
176            return plan_err!(
177                "SortMergeJoinExec requires valid sort expressions for its left side"
178            );
179        };
180        let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
181            return plan_err!(
182                "SortMergeJoinExec requires valid sort expressions for its right side"
183            );
184        };
185
186        let schema =
187            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
188        let cache =
189            Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
190        Ok(Self {
191            left,
192            right,
193            on,
194            filter,
195            join_type,
196            schema,
197            metrics: ExecutionPlanMetricsSet::new(),
198            left_sort_exprs,
199            right_sort_exprs,
200            sort_options,
201            null_equality,
202            cache: Arc::new(cache),
203        })
204    }
205
206    /// Get probe side (e.g streaming side) information for this sort merge join.
207    /// In current implementation, probe side is determined according to join type.
208    pub fn probe_side(join_type: &JoinType) -> JoinSide {
209        // When output schema contains only the right side, probe side is right.
210        // Otherwise probe side is the left side.
211        match join_type {
212            // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226)
213            JoinType::Right
214            | JoinType::RightSemi
215            | JoinType::RightAnti
216            | JoinType::RightMark => JoinSide::Right,
217            JoinType::Inner
218            | JoinType::Left
219            | JoinType::Full
220            | JoinType::LeftAnti
221            | JoinType::LeftSemi
222            | JoinType::LeftMark => JoinSide::Left,
223        }
224    }
225
226    /// Calculate order preservation flags for this sort merge join.
227    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
228        match join_type {
229            JoinType::Inner => vec![true, false],
230            JoinType::Left
231            | JoinType::LeftSemi
232            | JoinType::LeftAnti
233            | JoinType::LeftMark => vec![true, false],
234            JoinType::Right
235            | JoinType::RightSemi
236            | JoinType::RightAnti
237            | JoinType::RightMark => {
238                vec![false, true]
239            }
240            _ => vec![false, false],
241        }
242    }
243
244    /// Set of common columns used to join on
245    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
246        &self.on
247    }
248
249    /// Ref to right execution plan
250    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
251        &self.right
252    }
253
254    /// Join type
255    pub fn join_type(&self) -> JoinType {
256        self.join_type
257    }
258
259    /// Ref to left execution plan
260    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
261        &self.left
262    }
263
264    /// Ref to join filter
265    pub fn filter(&self) -> &Option<JoinFilter> {
266        &self.filter
267    }
268
269    /// Ref to sort options
270    pub fn sort_options(&self) -> &[SortOptions] {
271        &self.sort_options
272    }
273
274    /// Null equality
275    pub fn null_equality(&self) -> NullEquality {
276        self.null_equality
277    }
278
279    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
280    fn compute_properties(
281        left: &Arc<dyn ExecutionPlan>,
282        right: &Arc<dyn ExecutionPlan>,
283        schema: SchemaRef,
284        join_type: JoinType,
285        join_on: JoinOnRef,
286    ) -> Result<PlanProperties> {
287        // Calculate equivalence properties:
288        let eq_properties = join_equivalence_properties(
289            left.equivalence_properties().clone(),
290            right.equivalence_properties().clone(),
291            &join_type,
292            schema,
293            &Self::maintains_input_order(join_type),
294            Some(Self::probe_side(&join_type)),
295            join_on,
296        )?;
297
298        let output_partitioning =
299            symmetric_join_output_partitioning(left, right, &join_type)?;
300
301        Ok(PlanProperties::new(
302            eq_properties,
303            output_partitioning,
304            EmissionType::Incremental,
305            boundedness_from_children([left, right]),
306        ))
307    }
308
309    /// # Notes:
310    ///
311    /// This function should be called BEFORE inserting any repartitioning
312    /// operators on the join's children. Check [`super::super::HashJoinExec::swap_inputs`]
313    /// for more details.
314    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
315        let left = self.left();
316        let right = self.right();
317        let new_join = SortMergeJoinExec::try_new(
318            Arc::clone(right),
319            Arc::clone(left),
320            self.on()
321                .iter()
322                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
323                .collect::<Vec<_>>(),
324            self.filter().as_ref().map(JoinFilter::swap),
325            self.join_type().swap(),
326            self.sort_options.clone(),
327            self.null_equality,
328        )?;
329
330        // TODO: OR this condition with having a built-in projection (like
331        //       ordinary hash join) when we support it.
332        if matches!(
333            self.join_type(),
334            JoinType::LeftSemi
335                | JoinType::RightSemi
336                | JoinType::LeftAnti
337                | JoinType::RightAnti
338                | JoinType::LeftMark
339                | JoinType::RightMark
340        ) {
341            Ok(Arc::new(new_join))
342        } else {
343            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
344        }
345    }
346
347    fn with_new_children_and_same_properties(
348        &self,
349        mut children: Vec<Arc<dyn ExecutionPlan>>,
350    ) -> Self {
351        let left = children.swap_remove(0);
352        let right = children.swap_remove(0);
353        Self {
354            left,
355            right,
356            metrics: ExecutionPlanMetricsSet::new(),
357            ..Self::clone(self)
358        }
359    }
360}
361
362impl DisplayAs for SortMergeJoinExec {
363    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
364        match t {
365            DisplayFormatType::Default | DisplayFormatType::Verbose => {
366                let on = self
367                    .on
368                    .iter()
369                    .map(|(c1, c2)| format!("({c1}, {c2})"))
370                    .collect::<Vec<String>>()
371                    .join(", ");
372                let display_null_equality =
373                    if self.null_equality() == NullEquality::NullEqualsNull {
374                        ", NullsEqual: true"
375                    } else {
376                        ""
377                    };
378                write!(
379                    f,
380                    "{}: join_type={:?}, on=[{}]{}{}",
381                    Self::static_name(),
382                    self.join_type,
383                    on,
384                    self.filter.as_ref().map_or_else(
385                        || "".to_string(),
386                        |f| format!(", filter={}", f.expression())
387                    ),
388                    display_null_equality,
389                )
390            }
391            DisplayFormatType::TreeRender => {
392                let on = self
393                    .on
394                    .iter()
395                    .map(|(c1, c2)| {
396                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
397                    })
398                    .collect::<Vec<String>>()
399                    .join(", ");
400
401                if self.join_type() != JoinType::Inner {
402                    writeln!(f, "join_type={:?}", self.join_type)?;
403                }
404                writeln!(f, "on={on}")?;
405
406                if self.null_equality() == NullEquality::NullEqualsNull {
407                    writeln!(f, "NullsEqual: true")?;
408                }
409
410                Ok(())
411            }
412        }
413    }
414}
415
416impl ExecutionPlan for SortMergeJoinExec {
417    fn name(&self) -> &'static str {
418        "SortMergeJoinExec"
419    }
420
421    fn properties(&self) -> &Arc<PlanProperties> {
422        &self.cache
423    }
424
425    fn required_input_distribution(&self) -> Vec<Distribution> {
426        let (left_expr, right_expr) = self
427            .on
428            .iter()
429            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
430            .unzip();
431        vec![
432            Distribution::HashPartitioned(left_expr),
433            Distribution::HashPartitioned(right_expr),
434        ]
435    }
436
437    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
438        vec![
439            Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
440            Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
441        ]
442    }
443
444    fn maintains_input_order(&self) -> Vec<bool> {
445        Self::maintains_input_order(self.join_type)
446    }
447
448    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
449        vec![&self.left, &self.right]
450    }
451
452    fn with_new_children(
453        self: Arc<Self>,
454        children: Vec<Arc<dyn ExecutionPlan>>,
455    ) -> Result<Arc<dyn ExecutionPlan>> {
456        check_if_same_properties!(self, children);
457        match &children[..] {
458            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
459                Arc::clone(left),
460                Arc::clone(right),
461                self.on.clone(),
462                self.filter.clone(),
463                self.join_type,
464                self.sort_options.clone(),
465                self.null_equality,
466            )?)),
467            _ => internal_err!("SortMergeJoin wrong number of children"),
468        }
469    }
470
471    fn execute(
472        &self,
473        partition: usize,
474        context: Arc<TaskContext>,
475    ) -> Result<SendableRecordBatchStream> {
476        let left_partitions = self.left.output_partitioning().partition_count();
477        let right_partitions = self.right.output_partitioning().partition_count();
478        assert_eq_or_internal_err!(
479            left_partitions,
480            right_partitions,
481            "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
482                 consider using RepartitionExec"
483        );
484        let (on_left, on_right) = self.on.iter().cloned().unzip();
485        let (streamed, buffered, on_streamed, on_buffered) =
486            if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
487                (
488                    Arc::clone(&self.left),
489                    Arc::clone(&self.right),
490                    on_left,
491                    on_right,
492                )
493            } else {
494                (
495                    Arc::clone(&self.right),
496                    Arc::clone(&self.left),
497                    on_right,
498                    on_left,
499                )
500            };
501
502        // execute children plans
503        let streamed = streamed.execute(partition, Arc::clone(&context))?;
504        let buffered = buffered.execute(partition, Arc::clone(&context))?;
505
506        let batch_size = context.session_config().batch_size();
507        let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
508            .register(context.memory_pool());
509        let spill_manager = SpillManager::new(
510            context.runtime_env(),
511            SpillMetrics::new(&self.metrics, partition),
512            buffered.schema(),
513        )
514        .with_compression_type(context.session_config().spill_compression());
515
516        if matches!(
517            self.join_type,
518            JoinType::LeftSemi
519                | JoinType::LeftAnti
520                | JoinType::RightSemi
521                | JoinType::RightAnti
522                | JoinType::LeftMark
523                | JoinType::RightMark
524        ) {
525            Ok(Box::pin(BitwiseSortMergeJoinStream::try_new(
526                Arc::clone(&self.schema),
527                self.sort_options.clone(),
528                self.null_equality,
529                streamed,
530                buffered,
531                on_streamed,
532                on_buffered,
533                self.filter.clone(),
534                self.join_type,
535                batch_size,
536                partition,
537                &self.metrics,
538                reservation,
539                spill_manager,
540                context.runtime_env(),
541            )?))
542        } else {
543            Ok(Box::pin(MaterializingSortMergeJoinStream::try_new(
544                Arc::clone(&self.schema),
545                self.sort_options.clone(),
546                self.null_equality,
547                streamed,
548                buffered,
549                on_streamed,
550                on_buffered,
551                self.filter.clone(),
552                self.join_type,
553                batch_size,
554                SortMergeJoinMetrics::new(partition, &self.metrics),
555                reservation,
556                spill_manager,
557                context.runtime_env(),
558            )?))
559        }
560    }
561
562    fn metrics(&self) -> Option<MetricsSet> {
563        Some(self.metrics.clone_inner())
564    }
565
566    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
567        // SortMergeJoinExec uses symmetric hash partitioning where both left and right
568        // inputs are hash-partitioned on the join keys. This means partition `i` of the
569        // left input is joined with partition `i` of the right input.
570        //
571        // Therefore, partition-specific statistics can be computed by getting the
572        // partition-specific statistics from both children and combining them via
573        // `estimate_join_statistics`.
574        //
575        // TODO stats: it is not possible in general to know the output size of joins
576        // There are some special cases though, for example:
577        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
578        let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(partition)?);
579        let right_stats =
580            Arc::unwrap_or_clone(self.right.partition_statistics(partition)?);
581        Ok(Arc::new(estimate_join_statistics(
582            left_stats,
583            right_stats,
584            &self.on,
585            &self.join_type,
586            &self.schema,
587        )?))
588    }
589
590    /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done,
591    /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan.
592    /// Otherwise, it returns None.
593    fn try_swapping_with_projection(
594        &self,
595        projection: &ProjectionExec,
596    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
597        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
598        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
599        else {
600            return Ok(None);
601        };
602
603        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
604            self.left().schema().fields().len(),
605            &projection_as_columns,
606        );
607
608        if !join_allows_pushdown(
609            &projection_as_columns,
610            &self.schema(),
611            far_right_left_col_ind,
612            far_left_right_col_ind,
613        ) {
614            return Ok(None);
615        }
616
617        let Some(new_on) = update_join_on(
618            &projection_as_columns[0..=far_right_left_col_ind as _],
619            &projection_as_columns[far_left_right_col_ind as _..],
620            self.on(),
621            self.left().schema().fields().len(),
622        ) else {
623            return Ok(None);
624        };
625
626        let (new_left, new_right) = new_join_children(
627            &projection_as_columns,
628            far_right_left_col_ind,
629            far_left_right_col_ind,
630            self.children()[0],
631            self.children()[1],
632        )?;
633
634        Ok(Some(Arc::new(SortMergeJoinExec::try_new(
635            Arc::new(new_left),
636            Arc::new(new_right),
637            new_on,
638            self.filter.clone(),
639            self.join_type,
640            self.sort_options.clone(),
641            self.null_equality,
642        )?)))
643    }
644}