datafusion_physical_plan/joins/sort_merge_join/
exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the Sort-Merge join execution plan.
//! A Sort-Merge join plan consumes two child plans that are sorted on the join keys
//! and produces joined output according to the given join type and other options.
21
22use std::any::Any;
23use std::fmt::Formatter;
24use std::sync::Arc;
25
26use crate::execution_plan::{boundedness_from_children, EmissionType};
27use crate::expressions::PhysicalSortExpr;
28use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
29use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
30use crate::joins::utils::{
31    build_join_schema, check_join_is_valid, estimate_join_statistics,
32    reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
33    JoinOnRef,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
36use crate::projection::{
37    join_allows_pushdown, join_table_borders, new_join_children,
38    physical_to_column_exprs, update_join_on, ProjectionExec,
39};
40use crate::{
41    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
42    PlanProperties, SendableRecordBatchStream, Statistics,
43};
44
45use arrow::compute::SortOptions;
46use arrow::datatypes::SchemaRef;
47use datafusion_common::{
48    internal_err, plan_err, JoinSide, JoinType, NullEquality, Result,
49};
50use datafusion_execution::memory_pool::MemoryConsumer;
51use datafusion_execution::TaskContext;
52use datafusion_physical_expr::equivalence::join_equivalence_properties;
53use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
54use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
55
/// Join execution plan that executes equi-join predicates on multiple partitions using the
/// Sort-Merge join algorithm and applies an optional filter after the join. Can be used to join
/// arbitrarily large inputs where one or both of the inputs don't fit in the available memory.
///
/// # Join Expressions
///
/// Equi-join predicate (e.g. `<col1> = <col2>`) expressions are represented by [`Self::on`].
///
/// Non-equality predicates, which can not be pushed down to join inputs (e.g.
/// `<col1> != <col2>`) are known as "filter expressions" and are evaluated
/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional
/// expressions.
///
/// # Sorting
///
/// Assumes that both the left and right input to the join are pre-sorted on their join keys
/// (this requirement is reported through `required_input_ordering`). It is not the
/// responsibility of this execution plan to sort the inputs.
///
/// # "Streamed" vs "Buffered"
///
/// The number of record batches of streamed input currently present in memory depends
/// on the output batch size of the execution plan. There is no spilling support for streamed input.
/// The comparisons are performed from values of join keys in streamed input with the values of
/// join keys in buffered input. One row in a streamed record batch could be matched with multiple
/// rows in buffered input batches. The streamed input is managed through the states in
/// `StreamedState` and streamed input batches are represented by `StreamedBatch`.
///
/// Buffered input is buffered for all record batches having the same value of join key.
/// If the memory used by buffered batches exceeds the configured limit and spilling is enabled,
/// buffered batches could be spilled to disk. If spilling is disabled, the execution
/// will fail under the same conditions. Multiple buffered record batches could currently reside
/// in memory/disk during the execution. The number of buffered batches residing in
/// memory/disk depends on the number of rows of buffered input having the same value
/// of join key as that of streamed input rows currently present in memory. Because the inputs are
/// pre-sorted, the algorithm knows when a buffered batch is no longer needed, and releases it
/// from memory/disk. The buffered input is managed through the states in `BufferedState`
/// and buffered input batches are represented by `BufferedBatch`.
///
/// Depending on the type of join, left or right input may be selected as streamed or buffered
/// respectively (see [`Self::probe_side`]). For example, in a left-outer join, the left execution
/// plan will be selected as streamed input while in a right-outer join, the right execution plan
/// will be selected as the streamed input.
///
/// Reference for the algorithm:
/// <https://en.wikipedia.org/wiki/Sort-merge_join>.
///
/// Helpful short video demonstration:
/// <https://www.youtube.com/watch?v=jiWCPJtDE2c>.
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
    /// Left sorted joining execution plan
    pub left: Arc<dyn ExecutionPlan>,
    /// Right sorted joining execution plan
    pub right: Arc<dyn ExecutionPlan>,
    /// Pairs of `(left expression, right expression)` that must be equal for rows to join
    pub on: JoinOn,
    /// Optional non-equi filter applied while finding matching rows
    pub filter: Option<JoinFilter>,
    /// How the join is performed
    pub join_type: JoinType,
    /// The schema once the join is applied
    schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Ordering required of the left input: one sort expression per left join key,
    /// using the matching entry of `sort_options`
    left_sort_exprs: LexOrdering,
    /// Ordering required of the right input: one sort expression per right join key,
    /// using the matching entry of `sort_options`
    right_sort_exprs: LexOrdering,
    /// Sort options of join columns used in sorting left and right execution plans
    /// (one entry per key pair in `on`)
    pub sort_options: Vec<SortOptions>,
    /// Whether NULL join keys are treated as equal to each other
    /// (e.g. `NullEquality::NullEqualsNull`) when matching rows.
    pub null_equality: NullEquality,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
131
132impl SortMergeJoinExec {
133    /// Tries to create a new [SortMergeJoinExec].
134    /// The inputs are sorted using `sort_options` are applied to the columns in the `on`
135    /// # Error
136    /// This function errors when it is not possible to join the left and right sides on keys `on`.
137    pub fn try_new(
138        left: Arc<dyn ExecutionPlan>,
139        right: Arc<dyn ExecutionPlan>,
140        on: JoinOn,
141        filter: Option<JoinFilter>,
142        join_type: JoinType,
143        sort_options: Vec<SortOptions>,
144        null_equality: NullEquality,
145    ) -> Result<Self> {
146        let left_schema = left.schema();
147        let right_schema = right.schema();
148
149        check_join_is_valid(&left_schema, &right_schema, &on)?;
150        if sort_options.len() != on.len() {
151            return plan_err!(
152                "Expected number of sort options: {}, actual: {}",
153                on.len(),
154                sort_options.len()
155            );
156        }
157
158        let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
159            .iter()
160            .zip(sort_options.iter())
161            .map(|((l, r), sort_op)| {
162                let left = PhysicalSortExpr {
163                    expr: Arc::clone(l),
164                    options: *sort_op,
165                };
166                let right = PhysicalSortExpr {
167                    expr: Arc::clone(r),
168                    options: *sort_op,
169                };
170                (left, right)
171            })
172            .unzip();
173        let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
174            return plan_err!(
175                "SortMergeJoinExec requires valid sort expressions for its left side"
176            );
177        };
178        let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
179            return plan_err!(
180                "SortMergeJoinExec requires valid sort expressions for its right side"
181            );
182        };
183
184        let schema =
185            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
186        let cache =
187            Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
188        Ok(Self {
189            left,
190            right,
191            on,
192            filter,
193            join_type,
194            schema,
195            metrics: ExecutionPlanMetricsSet::new(),
196            left_sort_exprs,
197            right_sort_exprs,
198            sort_options,
199            null_equality,
200            cache,
201        })
202    }
203
204    /// Get probe side (e.g streaming side) information for this sort merge join.
205    /// In current implementation, probe side is determined according to join type.
206    pub fn probe_side(join_type: &JoinType) -> JoinSide {
207        // When output schema contains only the right side, probe side is right.
208        // Otherwise probe side is the left side.
209        match join_type {
210            // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226)
211            JoinType::Right
212            | JoinType::RightSemi
213            | JoinType::RightAnti
214            | JoinType::RightMark => JoinSide::Right,
215            JoinType::Inner
216            | JoinType::Left
217            | JoinType::Full
218            | JoinType::LeftAnti
219            | JoinType::LeftSemi
220            | JoinType::LeftMark => JoinSide::Left,
221        }
222    }
223
224    /// Calculate order preservation flags for this sort merge join.
225    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
226        match join_type {
227            JoinType::Inner => vec![true, false],
228            JoinType::Left
229            | JoinType::LeftSemi
230            | JoinType::LeftAnti
231            | JoinType::LeftMark => vec![true, false],
232            JoinType::Right
233            | JoinType::RightSemi
234            | JoinType::RightAnti
235            | JoinType::RightMark => {
236                vec![false, true]
237            }
238            _ => vec![false, false],
239        }
240    }
241
242    /// Set of common columns used to join on
243    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
244        &self.on
245    }
246
247    /// Ref to right execution plan
248    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
249        &self.right
250    }
251
252    /// Join type
253    pub fn join_type(&self) -> JoinType {
254        self.join_type
255    }
256
257    /// Ref to left execution plan
258    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
259        &self.left
260    }
261
262    /// Ref to join filter
263    pub fn filter(&self) -> &Option<JoinFilter> {
264        &self.filter
265    }
266
267    /// Ref to sort options
268    pub fn sort_options(&self) -> &[SortOptions] {
269        &self.sort_options
270    }
271
272    /// Null equality
273    pub fn null_equality(&self) -> NullEquality {
274        self.null_equality
275    }
276
277    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
278    fn compute_properties(
279        left: &Arc<dyn ExecutionPlan>,
280        right: &Arc<dyn ExecutionPlan>,
281        schema: SchemaRef,
282        join_type: JoinType,
283        join_on: JoinOnRef,
284    ) -> Result<PlanProperties> {
285        // Calculate equivalence properties:
286        let eq_properties = join_equivalence_properties(
287            left.equivalence_properties().clone(),
288            right.equivalence_properties().clone(),
289            &join_type,
290            schema,
291            &Self::maintains_input_order(join_type),
292            Some(Self::probe_side(&join_type)),
293            join_on,
294        )?;
295
296        let output_partitioning =
297            symmetric_join_output_partitioning(left, right, &join_type)?;
298
299        Ok(PlanProperties::new(
300            eq_properties,
301            output_partitioning,
302            EmissionType::Incremental,
303            boundedness_from_children([left, right]),
304        ))
305    }
306
307    /// # Notes:
308    ///
309    /// This function should be called BEFORE inserting any repartitioning
310    /// operators on the join's children. Check [`super::super::HashJoinExec::swap_inputs`]
311    /// for more details.
312    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
313        let left = self.left();
314        let right = self.right();
315        let new_join = SortMergeJoinExec::try_new(
316            Arc::clone(right),
317            Arc::clone(left),
318            self.on()
319                .iter()
320                .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
321                .collect::<Vec<_>>(),
322            self.filter().as_ref().map(JoinFilter::swap),
323            self.join_type().swap(),
324            self.sort_options.clone(),
325            self.null_equality,
326        )?;
327
328        // TODO: OR this condition with having a built-in projection (like
329        //       ordinary hash join) when we support it.
330        if matches!(
331            self.join_type(),
332            JoinType::LeftSemi
333                | JoinType::RightSemi
334                | JoinType::LeftAnti
335                | JoinType::RightAnti
336        ) {
337            Ok(Arc::new(new_join))
338        } else {
339            reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
340        }
341    }
342}
343
344impl DisplayAs for SortMergeJoinExec {
345    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
346        match t {
347            DisplayFormatType::Default | DisplayFormatType::Verbose => {
348                let on = self
349                    .on
350                    .iter()
351                    .map(|(c1, c2)| format!("({c1}, {c2})"))
352                    .collect::<Vec<String>>()
353                    .join(", ");
354                let display_null_equality =
355                    if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
356                        ", NullsEqual: true"
357                    } else {
358                        ""
359                    };
360                write!(
361                    f,
362                    "SortMergeJoin: join_type={:?}, on=[{}]{}{}",
363                    self.join_type,
364                    on,
365                    self.filter.as_ref().map_or_else(
366                        || "".to_string(),
367                        |f| format!(", filter={}", f.expression())
368                    ),
369                    display_null_equality,
370                )
371            }
372            DisplayFormatType::TreeRender => {
373                let on = self
374                    .on
375                    .iter()
376                    .map(|(c1, c2)| {
377                        format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
378                    })
379                    .collect::<Vec<String>>()
380                    .join(", ");
381
382                if self.join_type() != JoinType::Inner {
383                    writeln!(f, "join_type={:?}", self.join_type)?;
384                }
385                writeln!(f, "on={on}")?;
386
387                if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
388                    writeln!(f, "NullsEqual: true")?;
389                }
390
391                Ok(())
392            }
393        }
394    }
395}
396
397impl ExecutionPlan for SortMergeJoinExec {
398    fn name(&self) -> &'static str {
399        "SortMergeJoinExec"
400    }
401
402    fn as_any(&self) -> &dyn Any {
403        self
404    }
405
406    fn properties(&self) -> &PlanProperties {
407        &self.cache
408    }
409
410    fn required_input_distribution(&self) -> Vec<Distribution> {
411        let (left_expr, right_expr) = self
412            .on
413            .iter()
414            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
415            .unzip();
416        vec![
417            Distribution::HashPartitioned(left_expr),
418            Distribution::HashPartitioned(right_expr),
419        ]
420    }
421
422    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
423        vec![
424            Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
425            Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
426        ]
427    }
428
429    fn maintains_input_order(&self) -> Vec<bool> {
430        Self::maintains_input_order(self.join_type)
431    }
432
433    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
434        vec![&self.left, &self.right]
435    }
436
437    fn with_new_children(
438        self: Arc<Self>,
439        children: Vec<Arc<dyn ExecutionPlan>>,
440    ) -> Result<Arc<dyn ExecutionPlan>> {
441        match &children[..] {
442            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
443                Arc::clone(left),
444                Arc::clone(right),
445                self.on.clone(),
446                self.filter.clone(),
447                self.join_type,
448                self.sort_options.clone(),
449                self.null_equality,
450            )?)),
451            _ => internal_err!("SortMergeJoin wrong number of children"),
452        }
453    }
454
455    fn execute(
456        &self,
457        partition: usize,
458        context: Arc<TaskContext>,
459    ) -> Result<SendableRecordBatchStream> {
460        let left_partitions = self.left.output_partitioning().partition_count();
461        let right_partitions = self.right.output_partitioning().partition_count();
462        if left_partitions != right_partitions {
463            return internal_err!(
464                "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
465                 consider using RepartitionExec"
466            );
467        }
468        let (on_left, on_right) = self.on.iter().cloned().unzip();
469        let (streamed, buffered, on_streamed, on_buffered) =
470            if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
471                (
472                    Arc::clone(&self.left),
473                    Arc::clone(&self.right),
474                    on_left,
475                    on_right,
476                )
477            } else {
478                (
479                    Arc::clone(&self.right),
480                    Arc::clone(&self.left),
481                    on_right,
482                    on_left,
483                )
484            };
485
486        // execute children plans
487        let streamed = streamed.execute(partition, Arc::clone(&context))?;
488        let buffered = buffered.execute(partition, Arc::clone(&context))?;
489
490        // create output buffer
491        let batch_size = context.session_config().batch_size();
492
493        // create memory reservation
494        let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
495            .register(context.memory_pool());
496
497        // create join stream
498        Ok(Box::pin(SortMergeJoinStream::try_new(
499            context.session_config().spill_compression(),
500            Arc::clone(&self.schema),
501            self.sort_options.clone(),
502            self.null_equality,
503            streamed,
504            buffered,
505            on_streamed,
506            on_buffered,
507            self.filter.clone(),
508            self.join_type,
509            batch_size,
510            SortMergeJoinMetrics::new(partition, &self.metrics),
511            reservation,
512            context.runtime_env(),
513        )?))
514    }
515
516    fn metrics(&self) -> Option<MetricsSet> {
517        Some(self.metrics.clone_inner())
518    }
519
520    fn statistics(&self) -> Result<Statistics> {
521        self.partition_statistics(None)
522    }
523
524    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
525        if partition.is_some() {
526            return Ok(Statistics::new_unknown(&self.schema()));
527        }
528        // TODO stats: it is not possible in general to know the output size of joins
529        // There are some special cases though, for example:
530        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
531        estimate_join_statistics(
532            self.left.partition_statistics(None)?,
533            self.right.partition_statistics(None)?,
534            self.on.clone(),
535            &self.join_type,
536            &self.schema,
537        )
538    }
539
540    /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done,
541    /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan.
542    /// Otherwise, it returns None.
543    fn try_swapping_with_projection(
544        &self,
545        projection: &ProjectionExec,
546    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
547        // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed.
548        let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
549        else {
550            return Ok(None);
551        };
552
553        let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
554            self.left().schema().fields().len(),
555            &projection_as_columns,
556        );
557
558        if !join_allows_pushdown(
559            &projection_as_columns,
560            &self.schema(),
561            far_right_left_col_ind,
562            far_left_right_col_ind,
563        ) {
564            return Ok(None);
565        }
566
567        let Some(new_on) = update_join_on(
568            &projection_as_columns[0..=far_right_left_col_ind as _],
569            &projection_as_columns[far_left_right_col_ind as _..],
570            self.on(),
571            self.left().schema().fields().len(),
572        ) else {
573            return Ok(None);
574        };
575
576        let (new_left, new_right) = new_join_children(
577            &projection_as_columns,
578            far_right_left_col_ind,
579            far_left_right_col_ind,
580            self.children()[0],
581            self.children()[1],
582        )?;
583
584        Ok(Some(Arc::new(SortMergeJoinExec::try_new(
585            Arc::new(new_left),
586            Arc::new(new_right),
587            new_on,
588            self.filter.clone(),
589            self.join_type,
590            self.sort_options.clone(),
591            self.null_equality,
592        )?)))
593    }
594}