Skip to main content

datafusion_physical_plan/joins/sort_merge_join/
exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the Sort-Merge join execution plan.
19//! A Sort-Merge join plan consumes two sorted children plans and produces
20//! joined output by given join type and other options.
21
use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;

use crate::execution_plan::{EmissionType, boundedness_from_children};
use crate::expressions::PhysicalSortExpr;
use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
use crate::joins::utils::{
    JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid,
    estimate_join_statistics, reorder_output_after_swap,
    symmetric_join_output_partitioning,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{
    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
    physical_to_column_exprs, update_join_filter, update_join_on,
};
use crate::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    PlanProperties, SendableRecordBatchStream, Statistics,
};

use arrow::compute::SortOptions;
use arrow::datatypes::SchemaRef;
use datafusion_common::{
    JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err,
    plan_err,
};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
56
/// Join execution plan that executes equi-join predicates on multiple partitions using the
/// Sort-Merge join algorithm and applies an optional filter after the join. It can be used to
/// join arbitrarily large inputs where one or both of the inputs do not fit in the available
/// memory.
///
/// # Join Expressions
///
/// Equi-join predicates (e.g. `<col1> = <col2>`) are represented by [`Self::on`].
///
/// Non-equality predicates, which cannot be pushed down to the join inputs (e.g.
/// `<col1> != <col2>`), are known as "filter expressions" and are evaluated after the
/// equi-join predicates. They are represented by [`Self::filter`] and are optional.
///
/// # Sorting
///
/// Assumes that both the left and the right input of the join are pre-sorted. It is not the
/// responsibility of this execution plan to sort its inputs.
///
/// # "Streamed" vs "Buffered"
///
/// The number of streamed-input record batches held in memory depends on the output batch
/// size of the execution plan. There is no spilling support for the streamed input.
/// Join key values of the streamed input are compared against join key values of the
/// buffered input. One row of a streamed record batch can match multiple rows of the
/// buffered input batches. The streamed input is managed through the states in
/// `StreamedState`, and streamed input batches are represented by `StreamedBatch`.
///
/// The buffered input is buffered for all record batches that share the same join key value.
/// If memory usage exceeds the configured limit and spilling is enabled, buffered batches
/// can be spilled to disk. If spilling is disabled, execution fails under the same
/// conditions. Several buffered record batches may reside in memory or on disk at the same
/// time during execution. How many depends on the number of buffered rows whose join key
/// equals that of the streamed rows currently in memory. Because the inputs are pre-sorted,
/// the algorithm knows when buffered batches are no longer needed and releases them from
/// memory/disk. The buffered input is managed through the states in `BufferedState`, and
/// buffered input batches are represented by `BufferedBatch`.
///
/// Depending on the join type, either the left or the right input is selected as the
/// streamed input and the other one as the buffered input. For example, in a left-outer join
/// the left execution plan is the streamed input, while in a right-outer join the right
/// execution plan is the streamed input.
///
/// Reference for the algorithm:
/// <https://en.wikipedia.org/wiki/Sort-merge_join>.
///
/// Helpful short video demonstration:
/// <https://www.youtube.com/watch?v=jiWCPJtDE2c>.
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
    /// Left sorted joining execution plan
    pub left: Arc<dyn ExecutionPlan>,
    /// Right sorted joining execution plan
    pub right: Arc<dyn ExecutionPlan>,
    /// Pairs of (left, right) key expressions the inputs are joined on
    pub on: JoinOn,
    /// Optional filter which is applied while finding matching rows
    pub filter: Option<JoinFilter>,
    /// How the join is performed
    pub join_type: JoinType,
    /// The schema once the join is applied
    schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Sort expressions for the left input: the left join keys paired with `sort_options`
    left_sort_exprs: LexOrdering,
    /// Sort expressions for the right input: the right join keys paired with `sort_options`
    right_sort_exprs: LexOrdering,
    /// Sort options of the join columns, one entry per key pair in `on`, shared by the
    /// left and right execution plans
    pub sort_options: Vec<SortOptions>,
    /// Defines the null equality for the join.
    pub null_equality: NullEquality,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
132
133impl SortMergeJoinExec {
134    /// Tries to create a new [SortMergeJoinExec].
135    /// The inputs are sorted using `sort_options` are applied to the columns in the `on`
136    /// # Error
137    /// This function errors when it is not possible to join the left and right sides on keys `on`.
138    pub fn try_new(
139        left: Arc<dyn ExecutionPlan>,
140        right: Arc<dyn ExecutionPlan>,
141        on: JoinOn,
142        filter: Option<JoinFilter>,
143        join_type: JoinType,
144        sort_options: Vec<SortOptions>,
145        null_equality: NullEquality,
146    ) -> Result<Self> {
147        let left_schema = left.schema();
148        let right_schema = right.schema();
149
150        check_join_is_valid(&left_schema, &right_schema, &on)?;
151        if sort_options.len() != on.len() {
152            return plan_err!(
153                "Expected number of sort options: {}, actual: {}",
154                on.len(),
155                sort_options.len()
156            );
157        }
158
159        let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
160            .iter()
161            .zip(sort_options.iter())
162            .map(|((l, r), sort_op)| {
163                let left = PhysicalSortExpr {
164                    expr: Arc::clone(l),
165                    options: *sort_op,
166                };
167                let right = PhysicalSortExpr {
168                    expr: Arc::clone(r),
169                    options: *sort_op,
170                };
171                (left, right)
172            })
173            .unzip();
174        let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
175            return plan_err!(
176                "SortMergeJoinExec requires valid sort expressions for its left side"
177            );
178        };
179        let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
180            return plan_err!(
181                "SortMergeJoinExec requires valid sort expressions for its right side"
182            );
183        };
184
185        let schema =
186            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
187        let cache =
188            Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
189        Ok(Self {
190            left,
191            right,
192            on,
193            filter,
194            join_type,
195            schema,
196            metrics: ExecutionPlanMetricsSet::new(),
197            left_sort_exprs,
198            right_sort_exprs,
199            sort_options,
200            null_equality,
201            cache,
202        })
203    }
204
205    /// Get probe side (e.g streaming side) information for this sort merge join.
206    /// In current implementation, probe side is determined according to join type.
207    pub fn probe_side(join_type: &JoinType) -> JoinSide {
208        // When output schema contains only the right side, probe side is right.
209        // Otherwise probe side is the left side.
210        match join_type {
211            // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226)
212            JoinType::Right
213            | JoinType::RightSemi
214            | JoinType::RightAnti
215            | JoinType::RightMark => JoinSide::Right,
216            JoinType::Inner
217            | JoinType::Left
218            | JoinType::Full
219            | JoinType::LeftAnti
220            | JoinType::LeftSemi
221            | JoinType::LeftMark => JoinSide::Left,
222        }
223    }
224
225    /// Calculate order preservation flags for this sort merge join.
226    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
227        match join_type {
228            JoinType::Inner => vec![true, false],
229            JoinType::Left
230            | JoinType::LeftSemi
231            | JoinType::LeftAnti
232            | JoinType::LeftMark => vec![true, false],
233            JoinType::Right
234            | JoinType::RightSemi
235            | JoinType::RightAnti
236            | JoinType::RightMark => {
237                vec![false, true]
238            }
239            _ => vec![false, false],
240        }
241    }
242
243    /// Set of common columns used to join on
244    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
245        &self.on
246    }
247
    /// Reference to the right input execution plan
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
        &self.right
    }
252
    /// The type of join performed by this plan
    pub fn join_type(&self) -> JoinType {
        self.join_type
    }
257
    /// Reference to the left input execution plan
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
        &self.left
    }
262
    /// Reference to the optional join filter; `None` when the join has no filter
    pub fn filter(&self) -> &Option<JoinFilter> {
        &self.filter
    }
267
268    /// Ref to sort options
269    pub fn sort_options(&self) -> &[SortOptions] {
270        &self.sort_options
271    }
272
    /// The null equality setting this join was created with
    pub fn null_equality(&self) -> NullEquality {
        self.null_equality
    }
277
278    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
279    fn compute_properties(
280        left: &Arc<dyn ExecutionPlan>,
281        right: &Arc<dyn ExecutionPlan>,
282        schema: SchemaRef,
283        join_type: JoinType,
284        join_on: JoinOnRef,
285    ) -> Result<PlanProperties> {
286        // Calculate equivalence properties:
287        let eq_properties = join_equivalence_properties(
288            left.equivalence_properties().clone(),
289            right.equivalence_properties().clone(),
290            &join_type,
291            schema,
292            &Self::maintains_input_order(join_type),
293            Some(Self::probe_side(&join_type)),
294            join_on,
295        )?;
296
297        let output_partitioning =
298            symmetric_join_output_partitioning(left, right, &join_type)?;
299
300        Ok(PlanProperties::new(
301            eq_properties,
302            output_partitioning,
303            EmissionType::Incremental,
304            boundedness_from_children([left, right]),
305        ))
306    }
307
308    /// # Notes:
309    ///
310    /// This function should be called BEFORE inserting any repartitioning
311    /// operators on the join's children. Check [`super::super::HashJoinExec::swap_inputs`]
312    /// for more details.
313    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
314        let left = self.left();
315        let right = self.right();
316        let new_join = SortMergeJoinExec::try_new(
317            Arc::clone(right),
318            Arc::clone(left),
319            self.on()
320                .iter()
321                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
322                .collect::<Vec<_>>(),
323            self.filter().as_ref().map(JoinFilter::swap),
324            self.join_type().swap(),
325            self.sort_options.clone(),
326            self.null_equality,
327        )?;
328
329        // TODO: OR this condition with having a built-in projection (like
330        //       ordinary hash join) when we support it.
331        if matches!(
332            self.join_type(),
333            JoinType::LeftSemi
334                | JoinType::RightSemi
335                | JoinType::LeftAnti
336                | JoinType::RightAnti
337        ) {
338            Ok(Arc::new(new_join))
339        } else {
340            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
341        }
342    }
343}
344
345impl DisplayAs for SortMergeJoinExec {
346    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
347        match t {
348            DisplayFormatType::Default | DisplayFormatType::Verbose => {
349                let on = self
350                    .on
351                    .iter()
352                    .map(|(c1, c2)| format!("({c1}, {c2})"))
353                    .collect::<Vec<String>>()
354                    .join(", ");
355                let display_null_equality =
356                    if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
357                        ", NullsEqual: true"
358                    } else {
359                        ""
360                    };
361                write!(
362                    f,
363                    "{}: join_type={:?}, on=[{}]{}{}",
364                    Self::static_name(),
365                    self.join_type,
366                    on,
367                    self.filter.as_ref().map_or_else(
368                        || "".to_string(),
369                        |f| format!(", filter={}", f.expression())
370                    ),
371                    display_null_equality,
372                )
373            }
374            DisplayFormatType::TreeRender => {
375                let on = self
376                    .on
377                    .iter()
378                    .map(|(c1, c2)| {
379                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
380                    })
381                    .collect::<Vec<String>>()
382                    .join(", ");
383
384                if self.join_type() != JoinType::Inner {
385                    writeln!(f, "join_type={:?}", self.join_type)?;
386                }
387                writeln!(f, "on={on}")?;
388
389                if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
390                    writeln!(f, "NullsEqual: true")?;
391                }
392
393                Ok(())
394            }
395        }
396    }
397}
398
399impl ExecutionPlan for SortMergeJoinExec {
    /// Returns the static name of this operator.
    fn name(&self) -> &'static str {
        "SortMergeJoinExec"
    }
403
    /// Returns `self` as [`Any`] so that callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
407
    /// Returns the plan properties cached when this join was created.
    fn properties(&self) -> &PlanProperties {
        &self.cache
    }
411
412    fn required_input_distribution(&self) -> Vec<Distribution> {
413        let (left_expr, right_expr) = self
414            .on
415            .iter()
416            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
417            .unzip();
418        vec![
419            Distribution::HashPartitioned(left_expr),
420            Distribution::HashPartitioned(right_expr),
421        ]
422    }
423
424    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
425        vec![
426            Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
427            Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
428        ]
429    }
430
    /// Returns the per-input order preservation flags (`[left, right]`), which are
    /// derived from the join type alone.
    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order(self.join_type)
    }
434
    /// Returns the two inputs of the join in `[left, right]` order.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }
438
439    fn with_new_children(
440        self: Arc<Self>,
441        children: Vec<Arc<dyn ExecutionPlan>>,
442    ) -> Result<Arc<dyn ExecutionPlan>> {
443        match &children[..] {
444            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
445                Arc::clone(left),
446                Arc::clone(right),
447                self.on.clone(),
448                self.filter.clone(),
449                self.join_type,
450                self.sort_options.clone(),
451                self.null_equality,
452            )?)),
453            _ => internal_err!("SortMergeJoin wrong number of children"),
454        }
455    }
456
457    fn execute(
458        &self,
459        partition: usize,
460        context: Arc<TaskContext>,
461    ) -> Result<SendableRecordBatchStream> {
462        let left_partitions = self.left.output_partitioning().partition_count();
463        let right_partitions = self.right.output_partitioning().partition_count();
464        assert_eq_or_internal_err!(
465            left_partitions,
466            right_partitions,
467            "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
468                 consider using RepartitionExec"
469        );
470        let (on_left, on_right) = self.on.iter().cloned().unzip();
471        let (streamed, buffered, on_streamed, on_buffered) =
472            if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
473                (
474                    Arc::clone(&self.left),
475                    Arc::clone(&self.right),
476                    on_left,
477                    on_right,
478                )
479            } else {
480                (
481                    Arc::clone(&self.right),
482                    Arc::clone(&self.left),
483                    on_right,
484                    on_left,
485                )
486            };
487
488        // execute children plans
489        let streamed = streamed.execute(partition, Arc::clone(&context))?;
490        let buffered = buffered.execute(partition, Arc::clone(&context))?;
491
492        // create output buffer
493        let batch_size = context.session_config().batch_size();
494
495        // create memory reservation
496        let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
497            .register(context.memory_pool());
498
499        // create join stream
500        Ok(Box::pin(SortMergeJoinStream::try_new(
501            context.session_config().spill_compression(),
502            Arc::clone(&self.schema),
503            self.sort_options.clone(),
504            self.null_equality,
505            streamed,
506            buffered,
507            on_streamed,
508            on_buffered,
509            self.filter.clone(),
510            self.join_type,
511            batch_size,
512            SortMergeJoinMetrics::new(partition, &self.metrics),
513            reservation,
514            context.runtime_env(),
515        )?))
516    }
517
    /// Returns a copy of this operator's metrics set, obtained via `clone_inner`.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
521
    /// Returns the statistics of the whole plan; equivalent to
    /// `partition_statistics(None)`.
    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }
525
526    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
527        // SortMergeJoinExec uses symmetric hash partitioning where both left and right
528        // inputs are hash-partitioned on the join keys. This means partition `i` of the
529        // left input is joined with partition `i` of the right input.
530        //
531        // Therefore, partition-specific statistics can be computed by getting the
532        // partition-specific statistics from both children and combining them via
533        // `estimate_join_statistics`.
534        //
535        // TODO stats: it is not possible in general to know the output size of joins
536        // There are some special cases though, for example:
537        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
538        estimate_join_statistics(
539            self.left.partition_statistics(partition)?,
540            self.right.partition_statistics(partition)?,
541            &self.on,
542            &self.join_type,
543            &self.schema,
544        )
545    }
546
547    /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done,
548    /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan.
549    /// Otherwise, it returns None.
550    fn try_swapping_with_projection(
551        &self,
552        projection: &ProjectionExec,
553    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
554        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
555        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
556        else {
557            return Ok(None);
558        };
559
560        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
561            self.left().schema().fields().len(),
562            &projection_as_columns,
563        );
564
565        if !join_allows_pushdown(
566            &projection_as_columns,
567            &self.schema(),
568            far_right_left_col_ind,
569            far_left_right_col_ind,
570        ) {
571            return Ok(None);
572        }
573
574        let Some(new_on) = update_join_on(
575            &projection_as_columns[0..=far_right_left_col_ind as _],
576            &projection_as_columns[far_left_right_col_ind as _..],
577            self.on(),
578            self.left().schema().fields().len(),
579        ) else {
580            return Ok(None);
581        };
582
583        let (new_left, new_right) = new_join_children(
584            &projection_as_columns,
585            far_right_left_col_ind,
586            far_left_right_col_ind,
587            self.children()[0],
588            self.children()[1],
589        )?;
590
591        Ok(Some(Arc::new(SortMergeJoinExec::try_new(
592            Arc::new(new_left),
593            Arc::new(new_right),
594            new_on,
595            self.filter.clone(),
596            self.join_type,
597            self.sort_options.clone(),
598            self.null_equality,
599        )?)))
600    }
601}