Skip to main content

datafusion_physical_plan/sorts/
sort_preserving_merge.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.
19
20use std::sync::Arc;
21
22use crate::common::spawn_buffered;
23use crate::limit::LimitStream;
24use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
25use crate::projection::{ProjectionExec, make_with_child, update_ordering};
26use crate::sorts::streaming_merge::StreamingMergeBuilder;
27use crate::{
28    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
29    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
30    check_if_same_properties,
31};
32
33use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
34use datafusion_execution::TaskContext;
35use datafusion_execution::memory_pool::MemoryConsumer;
36use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
37
38use crate::execution_plan::{EvaluationType, SchedulingType};
39use log::{debug, trace};
40
41/// Sort preserving merge execution plan
42///
43/// # Overview
44///
45/// This operator implements a K-way merge. It is used to merge multiple sorted
46/// streams into a single sorted stream and is highly optimized.
47///
48/// ## Inputs:
49///
50/// 1. A list of sort expressions
51/// 2. An input plan, where each partition is sorted with respect to
52///    these sort expressions.
53///
54/// ## Output:
55///
56/// 1. A single partition that is also sorted with respect to the expressions
57///
58/// ## Diagram
59///
60/// ```text
61/// ┌─────────────────────────┐
62/// │ ┌───┬───┬───┬───┐       │
63/// │ │ A │ B │ C │ D │ ...   │──┐
64/// │ └───┴───┴───┴───┘       │  │
65/// └─────────────────────────┘  │  ┌───────────────────┐    ┌───────────────────────────────┐
66///   Stream 1                   │  │                   │    │ ┌───┬───╦═══╦───┬───╦═══╗     │
67///                              ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │
68///                              │  │                   │    │ └───┴─▲─╩═══╩───┴───╩═══╝     │
69/// ┌─────────────────────────┐  │  └───────────────────┘    └─┬─────┴───────────────────────┘
70/// │ ╔═══╦═══╗               │  │
71/// │ ║ B ║ E ║     ...       │──┘                             │
72/// │ ╚═══╩═══╝               │              Stable sort if `enable_round_robin_repartition=false`:
73/// └─────────────────────────┘              the merged stream places equal rows from stream 1
74///   Stream 2
75///
76///
77///  Input Partitions                                          Output Partition
78///    (sorted)                                                  (sorted)
79/// ```
80///
81/// # Error Handling
82///
83/// If any of the input partitions return an error, the error is propagated to
84/// the output and inputs are not polled again.
85#[derive(Debug, Clone)]
86pub struct SortPreservingMergeExec {
87    /// Input plan with sorted partitions
88    input: Arc<dyn ExecutionPlan>,
89    /// Sort expressions
90    expr: LexOrdering,
91    /// Execution metrics
92    metrics: ExecutionPlanMetricsSet,
93    /// Optional number of rows to fetch. Stops producing rows after this fetch
94    fetch: Option<usize>,
95    /// Cache holding plan properties like equivalences, output partitioning etc.
96    cache: Arc<PlanProperties>,
97    /// Use round-robin selection of tied winners of loser tree
98    ///
99    /// See [`Self::with_round_robin_repartition`] for more information.
100    enable_round_robin_repartition: bool,
101}
102
103impl SortPreservingMergeExec {
104    /// Create a new sort execution plan
105    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
106        let cache = Self::compute_properties(&input, expr.clone());
107        Self {
108            input,
109            expr,
110            metrics: ExecutionPlanMetricsSet::new(),
111            fetch: None,
112            cache: Arc::new(cache),
113            enable_round_robin_repartition: true,
114        }
115    }
116
117    /// Sets the number of rows to fetch
118    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
119        self.fetch = fetch;
120        self
121    }
122
123    /// Sets the selection strategy of tied winners of the loser tree algorithm
124    ///
125    /// If true (the default) equal output rows are placed in the merged stream
126    /// in round robin fashion. This approach consumes input streams at more
127    /// even rates when there are many rows with the same sort key.
128    ///
129    /// If false, equal output rows are always placed in the merged stream in
130    /// the order of the inputs, resulting in potentially slower execution but a
131    /// stable output order.
132    pub fn with_round_robin_repartition(
133        mut self,
134        enable_round_robin_repartition: bool,
135    ) -> Self {
136        self.enable_round_robin_repartition = enable_round_robin_repartition;
137        self
138    }
139
140    /// Input schema
141    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
142        &self.input
143    }
144
145    /// Sort expressions
146    pub fn expr(&self) -> &LexOrdering {
147        &self.expr
148    }
149
150    /// Fetch
151    pub fn fetch(&self) -> Option<usize> {
152        self.fetch
153    }
154
155    /// Creates the cache object that stores the plan properties
156    /// such as schema, equivalence properties, ordering, partitioning, etc.
157    fn compute_properties(
158        input: &Arc<dyn ExecutionPlan>,
159        ordering: LexOrdering,
160    ) -> PlanProperties {
161        let input_partitions = input.output_partitioning().partition_count();
162        let (drive, scheduling) = if input_partitions > 1 {
163            (EvaluationType::Eager, SchedulingType::Cooperative)
164        } else {
165            (
166                input.properties().evaluation_type,
167                input.properties().scheduling_type,
168            )
169        };
170
171        let mut eq_properties = input.equivalence_properties().clone();
172        eq_properties.clear_per_partition_constants();
173        eq_properties.add_ordering(ordering);
174        PlanProperties::new(
175            eq_properties,                        // Equivalence Properties
176            Partitioning::UnknownPartitioning(1), // Output Partitioning
177            input.pipeline_behavior(),            // Pipeline Behavior
178            input.boundedness(),                  // Boundedness
179        )
180        .with_evaluation_type(drive)
181        .with_scheduling_type(scheduling)
182    }
183
184    fn with_new_children_and_same_properties(
185        &self,
186        mut children: Vec<Arc<dyn ExecutionPlan>>,
187    ) -> Self {
188        Self {
189            input: children.swap_remove(0),
190            metrics: ExecutionPlanMetricsSet::new(),
191            ..Self::clone(self)
192        }
193    }
194}
195
196impl DisplayAs for SortPreservingMergeExec {
197    fn fmt_as(
198        &self,
199        t: DisplayFormatType,
200        f: &mut std::fmt::Formatter,
201    ) -> std::fmt::Result {
202        match t {
203            DisplayFormatType::Default | DisplayFormatType::Verbose => {
204                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
205                if let Some(fetch) = self.fetch {
206                    write!(f, ", fetch={fetch}")?;
207                };
208
209                Ok(())
210            }
211            DisplayFormatType::TreeRender => {
212                if let Some(fetch) = self.fetch {
213                    writeln!(f, "limit={fetch}")?;
214                };
215
216                for (i, e) in self.expr().iter().enumerate() {
217                    e.fmt_sql(f)?;
218                    if i != self.expr().len() - 1 {
219                        write!(f, ", ")?;
220                    }
221                }
222
223                Ok(())
224            }
225        }
226    }
227}
228
229impl ExecutionPlan for SortPreservingMergeExec {
230    fn name(&self) -> &'static str {
231        "SortPreservingMergeExec"
232    }
233
234    /// Return a reference to Any that can be used for downcasting
235    fn properties(&self) -> &Arc<PlanProperties> {
236        &self.cache
237    }
238
239    fn fetch(&self) -> Option<usize> {
240        self.fetch
241    }
242
243    /// Sets the number of rows to fetch
244    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
245        Some(Arc::new(Self {
246            input: Arc::clone(&self.input),
247            expr: self.expr.clone(),
248            metrics: self.metrics.clone(),
249            fetch: limit,
250            cache: Arc::clone(&self.cache),
251            enable_round_robin_repartition: self.enable_round_robin_repartition,
252        }))
253    }
254
255    fn with_preserve_order(
256        &self,
257        preserve_order: bool,
258    ) -> Option<Arc<dyn ExecutionPlan>> {
259        self.input
260            .with_preserve_order(preserve_order)
261            .and_then(|new_input| {
262                Arc::new(self.clone())
263                    .with_new_children(vec![new_input])
264                    .ok()
265            })
266    }
267
268    fn required_input_distribution(&self) -> Vec<Distribution> {
269        vec![Distribution::UnspecifiedDistribution]
270    }
271
272    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
273        vec![false]
274    }
275
276    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
277        vec![Some(OrderingRequirements::from(self.expr.clone()))]
278    }
279
280    fn maintains_input_order(&self) -> Vec<bool> {
281        vec![true]
282    }
283
284    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
285        vec![&self.input]
286    }
287
288    fn with_new_children(
289        self: Arc<Self>,
290        mut children: Vec<Arc<dyn ExecutionPlan>>,
291    ) -> Result<Arc<dyn ExecutionPlan>> {
292        check_if_same_properties!(self, children);
293        Ok(Arc::new(
294            SortPreservingMergeExec::new(self.expr.clone(), children.swap_remove(0))
295                .with_fetch(self.fetch),
296        ))
297    }
298
299    fn execute(
300        &self,
301        partition: usize,
302        context: Arc<TaskContext>,
303    ) -> Result<SendableRecordBatchStream> {
304        trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
305        assert_eq_or_internal_err!(
306            partition,
307            0,
308            "SortPreservingMergeExec invalid partition {partition}"
309        );
310
311        let input_partitions = self.input.output_partitioning().partition_count();
312        trace!(
313            "Number of input partitions of  SortPreservingMergeExec::execute: {input_partitions}"
314        );
315        let schema = self.schema();
316
317        let reservation =
318            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
319                .register(&context.runtime_env().memory_pool);
320
321        match input_partitions {
322            0 => internal_err!(
323                "SortPreservingMergeExec requires at least one input partition"
324            ),
325            1 => match self.fetch {
326                Some(fetch) => {
327                    let stream = self.input.execute(0, context)?;
328                    debug!(
329                        "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
330                    );
331                    Ok(Box::pin(LimitStream::new(
332                        stream,
333                        0,
334                        Some(fetch),
335                        BaselineMetrics::new(&self.metrics, partition),
336                    )))
337                }
338                None => {
339                    let stream = self.input.execute(0, context);
340                    debug!(
341                        "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
342                    );
343                    stream
344                }
345            },
346            _ => {
347                let receivers = (0..input_partitions)
348                    .map(|partition| {
349                        let stream =
350                            self.input.execute(partition, Arc::clone(&context))?;
351                        Ok(spawn_buffered(stream, 1))
352                    })
353                    .collect::<Result<_>>()?;
354
355                debug!(
356                    "Done setting up sender-receiver for SortPreservingMergeExec::execute"
357                );
358
359                let result = StreamingMergeBuilder::new()
360                    .with_streams(receivers)
361                    .with_schema(schema)
362                    .with_expressions(&self.expr)
363                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
364                    .with_batch_size(context.session_config().batch_size())
365                    .with_fetch(self.fetch)
366                    .with_reservation(reservation)
367                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
368                    .build()?;
369
370                debug!(
371                    "Got stream result from SortPreservingMergeStream::new_from_receivers"
372                );
373
374                Ok(result)
375            }
376        }
377    }
378
379    fn metrics(&self) -> Option<MetricsSet> {
380        Some(self.metrics.clone_inner())
381    }
382
383    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Arc<Statistics>> {
384        self.input.partition_statistics(None)
385    }
386
387    fn supports_limit_pushdown(&self) -> bool {
388        true
389    }
390
391    /// Tries to swap the projection with its input [`SortPreservingMergeExec`].
392    /// If this is possible, it returns the new [`SortPreservingMergeExec`] whose
393    /// child is a projection. Otherwise, it returns None.
394    fn try_swapping_with_projection(
395        &self,
396        projection: &ProjectionExec,
397    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
398        // If the projection does not narrow the schema, we should not try to push it down.
399        if projection.expr().len() >= projection.input().schema().fields().len() {
400            return Ok(None);
401        }
402
403        let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
404        else {
405            return Ok(None);
406        };
407
408        Ok(Some(Arc::new(
409            SortPreservingMergeExec::new(
410                updated_exprs,
411                make_with_child(projection, self.input())?,
412            )
413            .with_fetch(self.fetch()),
414        )))
415    }
416}
417
418#[cfg(test)]
419mod tests {
420    use std::collections::HashSet;
421    use std::fmt::Formatter;
422    use std::pin::Pin;
423    use std::sync::Mutex;
424    use std::task::{Context, Poll, Waker, ready};
425    use std::time::Duration;
426
427    use super::*;
428    use crate::coalesce_partitions::CoalescePartitionsExec;
429    use crate::execution_plan::{Boundedness, EmissionType};
430    use crate::expressions::col;
431    use crate::metrics::{MetricValue, Timestamp};
432    use crate::repartition::RepartitionExec;
433    use crate::sorts::sort::SortExec;
434    use crate::stream::RecordBatchReceiverStream;
435    use crate::test::TestMemoryExec;
436    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
437    use crate::test::{self, assert_is_pending, make_partition};
438    use crate::{collect, common};
439
440    use arrow::array::{
441        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
442        TimestampNanosecondArray,
443    };
444    use arrow::compute::SortOptions;
445    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
446    use datafusion_common::test_util::batches_to_string;
447    use datafusion_common::{assert_batches_eq, exec_err};
448    use datafusion_common_runtime::SpawnedTask;
449    use datafusion_execution::RecordBatchStream;
450    use datafusion_execution::config::SessionConfig;
451    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
452    use datafusion_physical_expr::EquivalenceProperties;
453    use datafusion_physical_expr::expressions::Column;
454    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
455    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
456
457    use futures::{FutureExt, Stream, StreamExt};
458    use insta::assert_snapshot;
459    use tokio::time::timeout;
460
461    // The number in the function is highly related to the memory limit we are testing
462    // any change of the constant should be aware of
463    fn generate_task_ctx_for_round_robin_tie_breaker(
464        target_batch_size: usize,
465    ) -> Result<Arc<TaskContext>> {
466        let runtime = RuntimeEnvBuilder::new()
467            .with_memory_limit(20_000_000, 1.0)
468            .build_arc()?;
469        let mut config = SessionConfig::new();
470        config.options_mut().execution.batch_size = target_batch_size;
471        let task_ctx = TaskContext::default()
472            .with_runtime(runtime)
473            .with_session_config(config);
474        Ok(Arc::new(task_ctx))
475    }
476
477    // The number in the function is highly related to the memory limit we are testing,
478    // any change of the constant should be aware of
479    fn generate_spm_for_round_robin_tie_breaker(
480        enable_round_robin_repartition: bool,
481    ) -> Result<Arc<SortPreservingMergeExec>> {
482        let row_size = 12500;
483        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
484        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
485        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
486        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;
487        let schema = rb.schema();
488
489        let rbs = std::iter::repeat_n(rb, 1024).collect::<Vec<_>>();
490        let sort = [
491            PhysicalSortExpr {
492                expr: col("b", &schema)?,
493                options: Default::default(),
494            },
495            PhysicalSortExpr {
496                expr: col("c", &schema)?,
497                options: Default::default(),
498            },
499        ]
500        .into();
501
502        let repartition_exec = RepartitionExec::try_new(
503            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
504            Partitioning::RoundRobinBatch(2),
505        )?;
506        let spm = SortPreservingMergeExec::new(sort, Arc::new(repartition_exec))
507            .with_round_robin_repartition(enable_round_robin_repartition);
508        Ok(Arc::new(spm))
509    }
510
511    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
512    /// Any errors here could indicate unintended changes in tie breaker logic.
513    ///
514    /// Note: If you adjust constants in this test, ensure that memory usage differs
515    /// based on whether the tie breaker is enabled or disabled.
516    #[tokio::test(flavor = "multi_thread")]
517    async fn test_round_robin_tie_breaker_success() -> Result<()> {
518        let target_batch_size = 12500;
519        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(target_batch_size)?;
520        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
521        let _collected = collect(spm, task_ctx).await?;
522        Ok(())
523    }
524
525    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
526    /// Any errors here could indicate unintended changes in tie breaker logic.
527    ///
528    /// Note: If you adjust constants in this test, ensure that memory usage differs
529    /// based on whether the tie breaker is enabled or disabled.
530    #[tokio::test(flavor = "multi_thread")]
531    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
532        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(8192)?;
533        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
534        let _err = collect(spm, task_ctx).await.unwrap_err();
535        Ok(())
536    }
537
538    #[tokio::test]
539    async fn test_merge_interleave() {
540        let task_ctx = Arc::new(TaskContext::default());
541        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
542        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
543            Some("a"),
544            Some("c"),
545            Some("e"),
546            Some("g"),
547            Some("j"),
548        ]));
549        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
550        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
551
552        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
553        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
554            Some("b"),
555            Some("d"),
556            Some("f"),
557            Some("h"),
558            Some("j"),
559        ]));
560        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
561        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
562
563        _test_merge(
564            &[vec![b1], vec![b2]],
565            &[
566                "+----+---+-------------------------------+",
567                "| a  | b | c                             |",
568                "+----+---+-------------------------------+",
569                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
570                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
571                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
572                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
573                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
574                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
575                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
576                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
577                "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1
578                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
579                "+----+---+-------------------------------+",
580            ],
581            task_ctx,
582        )
583        .await;
584    }
585
586    #[tokio::test]
587    async fn test_merge_some_overlap() {
588        let task_ctx = Arc::new(TaskContext::default());
589        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
590        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
591            Some("a"),
592            Some("b"),
593            Some("c"),
594            Some("d"),
595            Some("e"),
596        ]));
597        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
598        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
599
600        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
601        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
602            Some("c"),
603            Some("d"),
604            Some("e"),
605            Some("f"),
606            Some("g"),
607        ]));
608        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
609        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
610
611        _test_merge(
612            &[vec![b1], vec![b2]],
613            &[
614                "+-----+---+-------------------------------+",
615                "| a   | b | c                             |",
616                "+-----+---+-------------------------------+",
617                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
618                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
619                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
620                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
621                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
622                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
623                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
624                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
625                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
626                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
627                "+-----+---+-------------------------------+",
628            ],
629            task_ctx,
630        )
631        .await;
632    }
633
634    #[tokio::test]
635    async fn test_merge_no_overlap() {
636        let task_ctx = Arc::new(TaskContext::default());
637        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
638        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
639            Some("a"),
640            Some("b"),
641            Some("c"),
642            Some("d"),
643            Some("e"),
644        ]));
645        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
646        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
647
648        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
649        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
650            Some("f"),
651            Some("g"),
652            Some("h"),
653            Some("i"),
654            Some("j"),
655        ]));
656        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
657        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
658
659        _test_merge(
660            &[vec![b1], vec![b2]],
661            &[
662                "+----+---+-------------------------------+",
663                "| a  | b | c                             |",
664                "+----+---+-------------------------------+",
665                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
666                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
667                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
668                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
669                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
670                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
671                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
672                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
673                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
674                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
675                "+----+---+-------------------------------+",
676            ],
677            task_ctx,
678        )
679        .await;
680    }
681
682    #[tokio::test]
683    async fn test_merge_three_partitions() {
684        let task_ctx = Arc::new(TaskContext::default());
685        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
686        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
687            Some("a"),
688            Some("b"),
689            Some("c"),
690            Some("d"),
691            Some("f"),
692        ]));
693        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
694        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
695
696        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
697        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
698            Some("e"),
699            Some("g"),
700            Some("h"),
701            Some("i"),
702            Some("j"),
703        ]));
704        let c: ArrayRef =
705            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
706        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
707
708        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
709        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
710            Some("f"),
711            Some("g"),
712            Some("h"),
713            Some("i"),
714            Some("j"),
715        ]));
716        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
717        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
718
719        _test_merge(
720            &[vec![b1], vec![b2], vec![b3]],
721            &[
722                "+-----+---+-------------------------------+",
723                "| a   | b | c                             |",
724                "+-----+---+-------------------------------+",
725                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
726                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
727                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
728                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
729                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
730                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
731                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
732                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
733                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
734                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
735                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
736                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
737                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
738                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
739                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
740                "+-----+---+-------------------------------+",
741            ],
742            task_ctx,
743        )
744        .await;
745    }
746
747    async fn _test_merge(
748        partitions: &[Vec<RecordBatch>],
749        exp: &[&str],
750        context: Arc<TaskContext>,
751    ) {
752        let schema = partitions[0][0].schema();
753        let sort = [
754            PhysicalSortExpr {
755                expr: col("b", &schema).unwrap(),
756                options: Default::default(),
757            },
758            PhysicalSortExpr {
759                expr: col("c", &schema).unwrap(),
760                options: Default::default(),
761            },
762        ]
763        .into();
764        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
765        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
766
767        let collected = collect(merge, context).await.unwrap();
768        assert_batches_eq!(exp, collected.as_slice());
769    }
770
771    async fn sorted_merge(
772        input: Arc<dyn ExecutionPlan>,
773        sort: LexOrdering,
774        context: Arc<TaskContext>,
775    ) -> RecordBatch {
776        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
777        let mut result = collect(merge, context).await.unwrap();
778        assert_eq!(result.len(), 1);
779        result.remove(0)
780    }
781
782    async fn partition_sort(
783        input: Arc<dyn ExecutionPlan>,
784        sort: LexOrdering,
785        context: Arc<TaskContext>,
786    ) -> RecordBatch {
787        let sort_exec =
788            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
789        sorted_merge(sort_exec, sort, context).await
790    }
791
792    async fn basic_sort(
793        src: Arc<dyn ExecutionPlan>,
794        sort: LexOrdering,
795        context: Arc<TaskContext>,
796    ) -> RecordBatch {
797        let merge = Arc::new(CoalescePartitionsExec::new(src));
798        let sort_exec = Arc::new(SortExec::new(sort, merge));
799        let mut result = collect(sort_exec, context).await.unwrap();
800        assert_eq!(result.len(), 1);
801        result.remove(0)
802    }
803
804    #[tokio::test]
805    async fn test_partition_sort() -> Result<()> {
806        let task_ctx = Arc::new(TaskContext::default());
807        let partitions = 4;
808        let csv = test::scan_partitioned(partitions);
809        let schema = csv.schema();
810
811        let sort: LexOrdering = [PhysicalSortExpr {
812            expr: col("i", &schema)?,
813            options: SortOptions {
814                descending: true,
815                nulls_first: true,
816            },
817        }]
818        .into();
819
820        let basic =
821            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
822        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;
823
824        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
825            .unwrap()
826            .to_string();
827        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
828            .unwrap()
829            .to_string();
830
831        assert_eq!(
832            basic, partition,
833            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
834        );
835
836        Ok(())
837    }
838
839    // Split the provided record batch into multiple batch_size record batches
840    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
841        let batches = sorted.num_rows().div_ceil(batch_size);
842
843        // Split the sorted RecordBatch into multiple
844        (0..batches)
845            .map(|batch_idx| {
846                let columns = (0..sorted.num_columns())
847                    .map(|column_idx| {
848                        let length =
849                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);
850
851                        sorted
852                            .column(column_idx)
853                            .slice(batch_idx * batch_size, length)
854                    })
855                    .collect();
856
857                RecordBatch::try_new(sorted.schema(), columns).unwrap()
858            })
859            .collect()
860    }
861
862    async fn sorted_partitioned_input(
863        sort: LexOrdering,
864        sizes: &[usize],
865        context: Arc<TaskContext>,
866    ) -> Result<Arc<dyn ExecutionPlan>> {
867        let partitions = 4;
868        let csv = test::scan_partitioned(partitions);
869
870        let sorted = basic_sort(csv, sort, context).await;
871        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
872
873        TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
874    }
875
876    #[tokio::test]
877    async fn test_partition_sort_streaming_input() -> Result<()> {
878        let task_ctx = Arc::new(TaskContext::default());
879        let schema = make_partition(11).schema();
880        let sort: LexOrdering = [PhysicalSortExpr {
881            expr: col("i", &schema)?,
882            options: Default::default(),
883        }]
884        .into();
885
886        let input =
887            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
888                .await?;
889        let basic =
890            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
891        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;
892
893        assert_eq!(basic.num_rows(), 1200);
894        assert_eq!(partition.num_rows(), 1200);
895
896        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
897        let partition =
898            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();
899
900        assert_eq!(basic, partition);
901
902        Ok(())
903    }
904
905    #[tokio::test]
906    async fn test_partition_sort_streaming_input_output() -> Result<()> {
907        let schema = make_partition(11).schema();
908        let sort: LexOrdering = [PhysicalSortExpr {
909            expr: col("i", &schema)?,
910            options: Default::default(),
911        }]
912        .into();
913
914        // Test streaming with default batch size
915        let task_ctx = Arc::new(TaskContext::default());
916        let input =
917            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
918                .await?;
919        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;
920
921        // batch size of 23
922        let task_ctx = TaskContext::default()
923            .with_session_config(SessionConfig::new().with_batch_size(23));
924        let task_ctx = Arc::new(task_ctx);
925
926        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
927        let merged = collect(merge, task_ctx).await?;
928
929        assert_eq!(merged.len(), 53);
930        assert_eq!(basic.num_rows(), 1200);
931        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);
932
933        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
934        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();
935
936        assert_eq!(basic, partition);
937
938        Ok(())
939    }
940
941    #[tokio::test]
942    async fn test_nulls() {
943        let task_ctx = Arc::new(TaskContext::default());
944        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
945        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
946            None,
947            Some("a"),
948            Some("b"),
949            Some("d"),
950            Some("e"),
951        ]));
952        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
953            Some(8),
954            None,
955            Some(6),
956            None,
957            Some(4),
958        ]));
959        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
960
961        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
962        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
963            None,
964            Some("b"),
965            Some("g"),
966            Some("h"),
967            Some("i"),
968        ]));
969        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
970            Some(8),
971            None,
972            Some(5),
973            None,
974            Some(4),
975        ]));
976        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
977        let schema = b1.schema();
978
979        let sort = [
980            PhysicalSortExpr {
981                expr: col("b", &schema).unwrap(),
982                options: SortOptions {
983                    descending: false,
984                    nulls_first: true,
985                },
986            },
987            PhysicalSortExpr {
988                expr: col("c", &schema).unwrap(),
989                options: SortOptions {
990                    descending: false,
991                    nulls_first: false,
992                },
993            },
994        ]
995        .into();
996        let exec =
997            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
998        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
999
1000        let collected = collect(merge, task_ctx).await.unwrap();
1001        assert_eq!(collected.len(), 1);
1002
1003        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1004        +---+---+-------------------------------+
1005        | a | b | c                             |
1006        +---+---+-------------------------------+
1007        | 1 |   | 1970-01-01T00:00:00.000000008 |
1008        | 1 |   | 1970-01-01T00:00:00.000000008 |
1009        | 2 | a |                               |
1010        | 7 | b | 1970-01-01T00:00:00.000000006 |
1011        | 2 | b |                               |
1012        | 9 | d |                               |
1013        | 3 | e | 1970-01-01T00:00:00.000000004 |
1014        | 3 | g | 1970-01-01T00:00:00.000000005 |
1015        | 4 | h |                               |
1016        | 5 | i | 1970-01-01T00:00:00.000000004 |
1017        +---+---+-------------------------------+
1018        ");
1019    }
1020
1021    #[tokio::test]
1022    async fn test_sort_merge_single_partition_with_fetch() {
1023        let task_ctx = Arc::new(TaskContext::default());
1024        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1025        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1026        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1027        let schema = batch.schema();
1028
1029        let sort = [PhysicalSortExpr {
1030            expr: col("b", &schema).unwrap(),
1031            options: SortOptions {
1032                descending: false,
1033                nulls_first: true,
1034            },
1035        }]
1036        .into();
1037        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1038        let merge =
1039            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));
1040
1041        let collected = collect(merge, task_ctx).await.unwrap();
1042        assert_eq!(collected.len(), 1);
1043
1044        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1045        +---+---+
1046        | a | b |
1047        +---+---+
1048        | 1 | a |
1049        | 2 | b |
1050        +---+---+
1051        ");
1052    }
1053
1054    #[tokio::test]
1055    async fn test_sort_merge_single_partition_without_fetch() {
1056        let task_ctx = Arc::new(TaskContext::default());
1057        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1058        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1059        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1060        let schema = batch.schema();
1061
1062        let sort = [PhysicalSortExpr {
1063            expr: col("b", &schema).unwrap(),
1064            options: SortOptions {
1065                descending: false,
1066                nulls_first: true,
1067            },
1068        }]
1069        .into();
1070        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1071        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1072
1073        let collected = collect(merge, task_ctx).await.unwrap();
1074        assert_eq!(collected.len(), 1);
1075
1076        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1077        +---+---+
1078        | a | b |
1079        +---+---+
1080        | 1 | a |
1081        | 2 | b |
1082        | 7 | c |
1083        | 9 | d |
1084        | 3 | e |
1085        +---+---+
1086        ");
1087    }
1088
1089    #[tokio::test]
1090    async fn test_async() -> Result<()> {
1091        let task_ctx = Arc::new(TaskContext::default());
1092        let schema = make_partition(11).schema();
1093        let sort: LexOrdering = [PhysicalSortExpr {
1094            expr: col("i", &schema).unwrap(),
1095            options: SortOptions::default(),
1096        }]
1097        .into();
1098
1099        let batches =
1100            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
1101                .await?;
1102
1103        let partition_count = batches.output_partitioning().partition_count();
1104        let mut streams = Vec::with_capacity(partition_count);
1105
1106        for partition in 0..partition_count {
1107            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);
1108
1109            let sender = builder.tx();
1110
1111            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
1112            builder.spawn(async move {
1113                while let Some(batch) = stream.next().await {
1114                    sender.send(batch).await.unwrap();
1115                    // This causes the MergeStream to wait for more input
1116                    tokio::time::sleep(Duration::from_millis(10)).await;
1117                }
1118
1119                Ok(())
1120            });
1121
1122            streams.push(builder.build());
1123        }
1124
1125        let metrics = ExecutionPlanMetricsSet::new();
1126        let reservation =
1127            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);
1128
1129        let fetch = None;
1130        let merge_stream = StreamingMergeBuilder::new()
1131            .with_streams(streams)
1132            .with_schema(batches.schema())
1133            .with_expressions(&sort)
1134            .with_metrics(BaselineMetrics::new(&metrics, 0))
1135            .with_batch_size(task_ctx.session_config().batch_size())
1136            .with_fetch(fetch)
1137            .with_reservation(reservation)
1138            .build()?;
1139
1140        let mut merged = common::collect(merge_stream).await.unwrap();
1141
1142        assert_eq!(merged.len(), 1);
1143        let merged = merged.remove(0);
1144        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;
1145
1146        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
1147            .unwrap()
1148            .to_string();
1149        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
1150            .unwrap()
1151            .to_string();
1152
1153        assert_eq!(
1154            basic, partition,
1155            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
1156        );
1157
1158        Ok(())
1159    }
1160
1161    #[tokio::test]
1162    async fn test_merge_metrics() {
1163        let task_ctx = Arc::new(TaskContext::default());
1164        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1165        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
1166        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1167
1168        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
1169        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
1170        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1171
1172        let schema = b1.schema();
1173        let sort = [PhysicalSortExpr {
1174            expr: col("b", &schema).unwrap(),
1175            options: Default::default(),
1176        }]
1177        .into();
1178        let exec =
1179            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
1180        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1181
1182        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
1183            .await
1184            .unwrap();
1185        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1186        +----+---+
1187        | a  | b |
1188        +----+---+
1189        | 1  | a |
1190        | 10 | b |
1191        | 2  | c |
1192        | 20 | d |
1193        +----+---+
1194        ");
1195
1196        // Now, validate metrics
1197        let metrics = merge.metrics().unwrap();
1198
1199        assert_eq!(metrics.output_rows().unwrap(), 4);
1200        assert!(metrics.elapsed_compute().unwrap() > 0);
1201
1202        let mut saw_start = false;
1203        let mut saw_end = false;
1204        metrics.iter().for_each(|m| match m.value() {
1205            MetricValue::StartTimestamp(ts) => {
1206                saw_start = true;
1207                assert!(nanos_from_timestamp(ts) > 0);
1208            }
1209            MetricValue::EndTimestamp(ts) => {
1210                saw_end = true;
1211                assert!(nanos_from_timestamp(ts) > 0);
1212            }
1213            _ => {}
1214        });
1215
1216        assert!(saw_start);
1217        assert!(saw_end);
1218    }
1219
1220    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
1221        ts.value().unwrap().timestamp_nanos_opt().unwrap()
1222    }
1223
1224    #[tokio::test]
1225    async fn test_drop_cancel() -> Result<()> {
1226        let task_ctx = Arc::new(TaskContext::default());
1227        let schema =
1228            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1229
1230        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1231        let refs = blocking_exec.refs();
1232        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
1233            [PhysicalSortExpr {
1234                expr: col("a", &schema)?,
1235                options: SortOptions::default(),
1236            }]
1237            .into(),
1238            blocking_exec,
1239        ));
1240
1241        let fut = collect(sort_preserving_merge_exec, task_ctx);
1242        let mut fut = fut.boxed();
1243
1244        assert_is_pending(&mut fut);
1245        drop(fut);
1246        assert_strong_count_converges_to_zero(refs).await;
1247
1248        Ok(())
1249    }
1250
1251    #[tokio::test]
1252    async fn test_stable_sort() {
1253        let task_ctx = Arc::new(TaskContext::default());
1254
1255        // Create record batches like:
1256        // batch_number |value
1257        // -------------+------
1258        //    1         | A
1259        //    1         | B
1260        //
1261        // Ensure that the output is in the same order the batches were fed
1262        let partitions: Vec<Vec<RecordBatch>> = (0..10)
1263            .map(|batch_number| {
1264                let batch_number: Int32Array =
1265                    vec![Some(batch_number), Some(batch_number)]
1266                        .into_iter()
1267                        .collect();
1268                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();
1269
1270                let batch = RecordBatch::try_from_iter(vec![
1271                    ("batch_number", Arc::new(batch_number) as ArrayRef),
1272                    ("value", Arc::new(value) as ArrayRef),
1273                ])
1274                .unwrap();
1275
1276                vec![batch]
1277            })
1278            .collect();
1279
1280        let schema = partitions[0][0].schema();
1281
1282        let sort = [PhysicalSortExpr {
1283            expr: col("value", &schema).unwrap(),
1284            options: SortOptions {
1285                descending: false,
1286                nulls_first: true,
1287            },
1288        }]
1289        .into();
1290
1291        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
1292        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1293
1294        let collected = collect(merge, task_ctx).await.unwrap();
1295        assert_eq!(collected.len(), 1);
1296
1297        // Expect the data to be sorted first by "batch_number" (because
1298        // that was the order it was fed in, even though only "value"
1299        // is in the sort key)
1300        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1301        +--------------+-------+
1302        | batch_number | value |
1303        +--------------+-------+
1304        | 0            | A     |
1305        | 1            | A     |
1306        | 2            | A     |
1307        | 3            | A     |
1308        | 4            | A     |
1309        | 5            | A     |
1310        | 6            | A     |
1311        | 7            | A     |
1312        | 8            | A     |
1313        | 9            | A     |
1314        | 0            | B     |
1315        | 1            | B     |
1316        | 2            | B     |
1317        | 3            | B     |
1318        | 4            | B     |
1319        | 5            | B     |
1320        | 6            | B     |
1321        | 7            | B     |
1322        | 8            | B     |
1323        | 9            | B     |
1324        +--------------+-------+
1325        ");
1326    }
1327
1328    #[derive(Debug)]
1329    struct CongestionState {
1330        wakers: Vec<Waker>,
1331        unpolled_partitions: HashSet<usize>,
1332    }
1333
1334    #[derive(Debug)]
1335    struct Congestion {
1336        congestion_state: Mutex<CongestionState>,
1337    }
1338
1339    impl Congestion {
1340        fn new(partition_count: usize) -> Self {
1341            Congestion {
1342                congestion_state: Mutex::new(CongestionState {
1343                    wakers: vec![],
1344                    unpolled_partitions: (0usize..partition_count).collect(),
1345                }),
1346            }
1347        }
1348
1349        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
1350            let mut state = self.congestion_state.lock().unwrap();
1351
1352            state.unpolled_partitions.remove(&partition);
1353
1354            if state.unpolled_partitions.is_empty() {
1355                state.wakers.iter().for_each(|w| w.wake_by_ref());
1356                state.wakers.clear();
1357                Poll::Ready(())
1358            } else {
1359                state.wakers.push(cx.waker().clone());
1360                Poll::Pending
1361            }
1362        }
1363    }
1364
    /// Test plan whose `execute` hands out a [`CongestedStream`] per partition.
    ///
    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug, Clone)]
    struct CongestedExec {
        // Schema cloned into every stream created by `execute`.
        schema: Schema,
        // Plan properties returned from `properties()`.
        cache: Arc<PlanProperties>,
        // Gate shared by all streams created from this plan.
        congestion: Arc<Congestion>,
    }
1373
1374    impl CongestedExec {
1375        fn compute_properties(schema: SchemaRef) -> PlanProperties {
1376            let columns = schema
1377                .fields
1378                .iter()
1379                .enumerate()
1380                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
1381                .collect::<Vec<_>>();
1382            let mut eq_properties = EquivalenceProperties::new(schema);
1383            eq_properties.add_ordering(
1384                columns
1385                    .iter()
1386                    .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
1387            );
1388            PlanProperties::new(
1389                eq_properties,
1390                Partitioning::Hash(columns, 3),
1391                EmissionType::Incremental,
1392                Boundedness::Unbounded {
1393                    requires_infinite_memory: false,
1394                },
1395            )
1396        }
1397    }
1398
1399    impl ExecutionPlan for CongestedExec {
1400        fn name(&self) -> &'static str {
1401            Self::static_name()
1402        }
1403        fn properties(&self) -> &Arc<PlanProperties> {
1404            &self.cache
1405        }
1406        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1407            vec![]
1408        }
1409        fn with_new_children(
1410            self: Arc<Self>,
1411            _: Vec<Arc<dyn ExecutionPlan>>,
1412        ) -> Result<Arc<dyn ExecutionPlan>> {
1413            Ok(self)
1414        }
1415        fn execute(
1416            &self,
1417            partition: usize,
1418            _context: Arc<TaskContext>,
1419        ) -> Result<SendableRecordBatchStream> {
1420            Ok(Box::pin(CongestedStream {
1421                schema: Arc::new(self.schema.clone()),
1422                none_polled_once: false,
1423                congestion: Arc::clone(&self.congestion),
1424                partition,
1425            }))
1426        }
1427    }
1428
1429    impl DisplayAs for CongestedExec {
1430        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1431            match t {
1432                DisplayFormatType::Default | DisplayFormatType::Verbose => {
1433                    write!(f, "CongestedExec",).unwrap()
1434                }
1435                DisplayFormatType::TreeRender => {
1436                    // TODO: collect info
1437                    write!(f, "").unwrap()
1438                }
1439            }
1440            Ok(())
1441        }
1442    }
1443
    /// Stream for one partition of a `CongestedExec`; it never yields a batch.
    ///
    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug)]
    pub struct CongestedStream {
        // Schema returned from `RecordBatchStream::schema`.
        schema: SchemaRef,
        // Set after partition 0 has returned `None`; another poll of partition 0 then panics.
        none_polled_once: bool,
        // Gate consulted on every poll.
        congestion: Arc<Congestion>,
        // Index of the partition this stream serves.
        partition: usize,
    }
1453
1454    impl Stream for CongestedStream {
1455        type Item = Result<RecordBatch>;
1456        fn poll_next(
1457            mut self: Pin<&mut Self>,
1458            cx: &mut Context<'_>,
1459        ) -> Poll<Option<Self::Item>> {
1460            match self.partition {
1461                0 => {
1462                    let _ = self.congestion.check_congested(self.partition, cx);
1463                    if self.none_polled_once {
1464                        panic!("Exhausted stream is polled more than once")
1465                    } else {
1466                        self.none_polled_once = true;
1467                        Poll::Ready(None)
1468                    }
1469                }
1470                _ => {
1471                    ready!(self.congestion.check_congested(self.partition, cx));
1472                    Poll::Ready(None)
1473                }
1474            }
1475        }
1476    }
1477
1478    impl RecordBatchStream for CongestedStream {
1479        fn schema(&self) -> SchemaRef {
1480            Arc::clone(&self.schema)
1481        }
1482    }
1483
1484    #[tokio::test]
1485    async fn test_spm_congestion() -> Result<()> {
1486        let task_ctx = Arc::new(TaskContext::default());
1487        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
1488        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
1489        let &partition_count = match properties.output_partitioning() {
1490            Partitioning::RoundRobinBatch(partitions) => partitions,
1491            Partitioning::Hash(_, partitions) => partitions,
1492            Partitioning::UnknownPartitioning(partitions) => partitions,
1493        };
1494        let source = CongestedExec {
1495            schema: schema.clone(),
1496            cache: Arc::new(properties),
1497            congestion: Arc::new(Congestion::new(partition_count)),
1498        };
1499        let spm = SortPreservingMergeExec::new(
1500            [PhysicalSortExpr::new_default(Arc::new(Column::new(
1501                "c1", 0,
1502            )))]
1503            .into(),
1504            Arc::new(source),
1505        );
1506        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));
1507
1508        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
1509        match result {
1510            Ok(Ok(Ok(_batches))) => Ok(()),
1511            Ok(Ok(Err(e))) => Err(e),
1512            Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
1513            Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
1514        }
1515    }
1516
1517    #[tokio::test]
1518    async fn test_sort_merge_stops_after_error_with_buffered_rows() -> Result<()> {
1519        let task_ctx = Arc::new(TaskContext::default());
1520        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
1521        let sort: LexOrdering = [PhysicalSortExpr::new_default(Arc::new(Column::new(
1522            "i", 0,
1523        ))
1524            as Arc<dyn PhysicalExpr>)]
1525        .into();
1526
1527        let mut stream0 = RecordBatchReceiverStream::builder(Arc::clone(&schema), 2);
1528        let tx0 = stream0.tx();
1529        let schema0 = Arc::clone(&schema);
1530        stream0.spawn(async move {
1531            let batch =
1532                RecordBatch::try_new(schema0, vec![Arc::new(Int32Array::from(vec![1]))])?;
1533            tx0.send(Ok(batch)).await.unwrap();
1534            tx0.send(exec_err!("stream failure")).await.unwrap();
1535            Ok(())
1536        });
1537
1538        let mut stream1 = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);
1539        let tx1 = stream1.tx();
1540        let schema1 = Arc::clone(&schema);
1541        stream1.spawn(async move {
1542            let batch =
1543                RecordBatch::try_new(schema1, vec![Arc::new(Int32Array::from(vec![2]))])?;
1544            tx1.send(Ok(batch)).await.unwrap();
1545            Ok(())
1546        });
1547
1548        let metrics = ExecutionPlanMetricsSet::new();
1549        let reservation =
1550            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);
1551
1552        let mut merge_stream = StreamingMergeBuilder::new()
1553            .with_streams(vec![stream0.build(), stream1.build()])
1554            .with_schema(Arc::clone(&schema))
1555            .with_expressions(&sort)
1556            .with_metrics(BaselineMetrics::new(&metrics, 0))
1557            .with_batch_size(task_ctx.session_config().batch_size())
1558            .with_fetch(None)
1559            .with_reservation(reservation)
1560            .build()?;
1561
1562        let first = merge_stream.next().await.unwrap();
1563        assert!(first.is_err(), "expected merge stream to surface the error");
1564        assert!(
1565            merge_stream.next().await.is_none(),
1566            "merge stream yielded data after returning an error"
1567        );
1568
1569        Ok(())
1570    }
1571}