// datafusion_physical_plan/sorts/sort_preserving_merge.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.
19
20use std::any::Any;
21use std::sync::Arc;
22
23use crate::common::spawn_buffered;
24use crate::limit::LimitStream;
25use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
26use crate::projection::{ProjectionExec, make_with_child, update_ordering};
27use crate::sorts::streaming_merge::StreamingMergeBuilder;
28use crate::{
29    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
30    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
31};
32
33use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
34use datafusion_execution::TaskContext;
35use datafusion_execution::memory_pool::MemoryConsumer;
36use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
37
38use crate::execution_plan::{EvaluationType, SchedulingType};
39use log::{debug, trace};
40
/// Sort preserving merge execution plan
///
/// # Overview
///
/// This operator implements a K-way merge. It is used to merge multiple sorted
/// streams into a single sorted stream and is highly optimized.
///
/// ## Inputs:
///
/// 1. A list of sort expressions
/// 2. An input plan, where each partition is sorted with respect to
///    these sort expressions.
///
/// ## Output:
///
/// 1. A single partition that is also sorted with respect to the expressions
///
/// ## Diagram
///
/// ```text
/// ┌─────────────────────────┐
/// │ ┌───┬───┬───┬───┐       │
/// │ │ A │ B │ C │ D │ ...   │──┐
/// │ └───┴───┴───┴───┘       │  │
/// └─────────────────────────┘  │  ┌───────────────────┐    ┌───────────────────────────────┐
///   Stream 1                   │  │                   │    │ ┌───┬───╦═══╦───┬───╦═══╗     │
///                              ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │
///                              │  │                   │    │ └───┴─▲─╩═══╩───┴───╩═══╝     │
/// ┌─────────────────────────┐  │  └───────────────────┘    └─┬─────┴───────────────────────┘
/// │ ╔═══╦═══╗               │  │
/// │ ║ B ║ E ║     ...       │──┘                             │
/// │ ╚═══╩═══╝               │              Stable sort if `enable_round_robin_repartition=false`:
/// └─────────────────────────┘              the merged stream places equal rows from stream 1
///   Stream 2
///
///
///  Input Partitions                                          Output Partition
///    (sorted)                                                  (sorted)
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// the output and inputs are not polled again.
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan with sorted partitions
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions (each input partition must already satisfy this ordering)
    expr: LexOrdering,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch. Stops producing rows after this fetch
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
    /// Use round-robin selection of tied winners of loser tree
    ///
    /// See [`Self::with_round_robin_repartition`] for more information.
    enable_round_robin_repartition: bool,
}
102
103impl SortPreservingMergeExec {
104    /// Create a new sort execution plan
105    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
106        let cache = Self::compute_properties(&input, expr.clone());
107        Self {
108            input,
109            expr,
110            metrics: ExecutionPlanMetricsSet::new(),
111            fetch: None,
112            cache,
113            enable_round_robin_repartition: true,
114        }
115    }
116
117    /// Sets the number of rows to fetch
118    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
119        self.fetch = fetch;
120        self
121    }
122
123    /// Sets the selection strategy of tied winners of the loser tree algorithm
124    ///
125    /// If true (the default) equal output rows are placed in the merged stream
126    /// in round robin fashion. This approach consumes input streams at more
127    /// even rates when there are many rows with the same sort key.
128    ///
129    /// If false, equal output rows are always placed in the merged stream in
130    /// the order of the inputs, resulting in potentially slower execution but a
131    /// stable output order.
132    pub fn with_round_robin_repartition(
133        mut self,
134        enable_round_robin_repartition: bool,
135    ) -> Self {
136        self.enable_round_robin_repartition = enable_round_robin_repartition;
137        self
138    }
139
140    /// Input schema
141    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
142        &self.input
143    }
144
145    /// Sort expressions
146    pub fn expr(&self) -> &LexOrdering {
147        &self.expr
148    }
149
150    /// Fetch
151    pub fn fetch(&self) -> Option<usize> {
152        self.fetch
153    }
154
155    /// Creates the cache object that stores the plan properties
156    /// such as schema, equivalence properties, ordering, partitioning, etc.
157    fn compute_properties(
158        input: &Arc<dyn ExecutionPlan>,
159        ordering: LexOrdering,
160    ) -> PlanProperties {
161        let input_partitions = input.output_partitioning().partition_count();
162        let (drive, scheduling) = if input_partitions > 1 {
163            (EvaluationType::Eager, SchedulingType::Cooperative)
164        } else {
165            (
166                input.properties().evaluation_type,
167                input.properties().scheduling_type,
168            )
169        };
170
171        let mut eq_properties = input.equivalence_properties().clone();
172        eq_properties.clear_per_partition_constants();
173        eq_properties.add_ordering(ordering);
174        PlanProperties::new(
175            eq_properties,                        // Equivalence Properties
176            Partitioning::UnknownPartitioning(1), // Output Partitioning
177            input.pipeline_behavior(),            // Pipeline Behavior
178            input.boundedness(),                  // Boundedness
179        )
180        .with_evaluation_type(drive)
181        .with_scheduling_type(scheduling)
182    }
183}
184
185impl DisplayAs for SortPreservingMergeExec {
186    fn fmt_as(
187        &self,
188        t: DisplayFormatType,
189        f: &mut std::fmt::Formatter,
190    ) -> std::fmt::Result {
191        match t {
192            DisplayFormatType::Default | DisplayFormatType::Verbose => {
193                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
194                if let Some(fetch) = self.fetch {
195                    write!(f, ", fetch={fetch}")?;
196                };
197
198                Ok(())
199            }
200            DisplayFormatType::TreeRender => {
201                if let Some(fetch) = self.fetch {
202                    writeln!(f, "limit={fetch}")?;
203                };
204
205                for (i, e) in self.expr().iter().enumerate() {
206                    e.fmt_sql(f)?;
207                    if i != self.expr().len() - 1 {
208                        write!(f, ", ")?;
209                    }
210                }
211
212                Ok(())
213            }
214        }
215    }
216}
217
218impl ExecutionPlan for SortPreservingMergeExec {
219    fn name(&self) -> &'static str {
220        "SortPreservingMergeExec"
221    }
222
223    /// Return a reference to Any that can be used for downcasting
224    fn as_any(&self) -> &dyn Any {
225        self
226    }
227
228    fn properties(&self) -> &PlanProperties {
229        &self.cache
230    }
231
232    fn fetch(&self) -> Option<usize> {
233        self.fetch
234    }
235
236    /// Sets the number of rows to fetch
237    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
238        Some(Arc::new(Self {
239            input: Arc::clone(&self.input),
240            expr: self.expr.clone(),
241            metrics: self.metrics.clone(),
242            fetch: limit,
243            cache: self.cache.clone(),
244            enable_round_robin_repartition: true,
245        }))
246    }
247
248    fn required_input_distribution(&self) -> Vec<Distribution> {
249        vec![Distribution::UnspecifiedDistribution]
250    }
251
252    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
253        vec![false]
254    }
255
256    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
257        vec![Some(OrderingRequirements::from(self.expr.clone()))]
258    }
259
260    fn maintains_input_order(&self) -> Vec<bool> {
261        vec![true]
262    }
263
264    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
265        vec![&self.input]
266    }
267
268    fn with_new_children(
269        self: Arc<Self>,
270        children: Vec<Arc<dyn ExecutionPlan>>,
271    ) -> Result<Arc<dyn ExecutionPlan>> {
272        Ok(Arc::new(
273            SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
274                .with_fetch(self.fetch),
275        ))
276    }
277
278    fn execute(
279        &self,
280        partition: usize,
281        context: Arc<TaskContext>,
282    ) -> Result<SendableRecordBatchStream> {
283        trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
284        assert_eq_or_internal_err!(
285            partition,
286            0,
287            "SortPreservingMergeExec invalid partition {partition}"
288        );
289
290        let input_partitions = self.input.output_partitioning().partition_count();
291        trace!(
292            "Number of input partitions of  SortPreservingMergeExec::execute: {input_partitions}"
293        );
294        let schema = self.schema();
295
296        let reservation =
297            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
298                .register(&context.runtime_env().memory_pool);
299
300        match input_partitions {
301            0 => internal_err!(
302                "SortPreservingMergeExec requires at least one input partition"
303            ),
304            1 => match self.fetch {
305                Some(fetch) => {
306                    let stream = self.input.execute(0, context)?;
307                    debug!(
308                        "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
309                    );
310                    Ok(Box::pin(LimitStream::new(
311                        stream,
312                        0,
313                        Some(fetch),
314                        BaselineMetrics::new(&self.metrics, partition),
315                    )))
316                }
317                None => {
318                    let stream = self.input.execute(0, context);
319                    debug!(
320                        "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
321                    );
322                    stream
323                }
324            },
325            _ => {
326                let receivers = (0..input_partitions)
327                    .map(|partition| {
328                        let stream =
329                            self.input.execute(partition, Arc::clone(&context))?;
330                        Ok(spawn_buffered(stream, 1))
331                    })
332                    .collect::<Result<_>>()?;
333
334                debug!(
335                    "Done setting up sender-receiver for SortPreservingMergeExec::execute"
336                );
337
338                let result = StreamingMergeBuilder::new()
339                    .with_streams(receivers)
340                    .with_schema(schema)
341                    .with_expressions(&self.expr)
342                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
343                    .with_batch_size(context.session_config().batch_size())
344                    .with_fetch(self.fetch)
345                    .with_reservation(reservation)
346                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
347                    .build()?;
348
349                debug!(
350                    "Got stream result from SortPreservingMergeStream::new_from_receivers"
351                );
352
353                Ok(result)
354            }
355        }
356    }
357
358    fn metrics(&self) -> Option<MetricsSet> {
359        Some(self.metrics.clone_inner())
360    }
361
362    fn statistics(&self) -> Result<Statistics> {
363        self.input.partition_statistics(None)
364    }
365
366    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
367        self.input.partition_statistics(None)
368    }
369
370    fn supports_limit_pushdown(&self) -> bool {
371        true
372    }
373
374    /// Tries to swap the projection with its input [`SortPreservingMergeExec`].
375    /// If this is possible, it returns the new [`SortPreservingMergeExec`] whose
376    /// child is a projection. Otherwise, it returns None.
377    fn try_swapping_with_projection(
378        &self,
379        projection: &ProjectionExec,
380    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
381        // If the projection does not narrow the schema, we should not try to push it down.
382        if projection.expr().len() >= projection.input().schema().fields().len() {
383            return Ok(None);
384        }
385
386        let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
387        else {
388            return Ok(None);
389        };
390
391        Ok(Some(Arc::new(
392            SortPreservingMergeExec::new(
393                updated_exprs,
394                make_with_child(projection, self.input())?,
395            )
396            .with_fetch(self.fetch()),
397        )))
398    }
399}
400
401#[cfg(test)]
402mod tests {
403    use std::collections::HashSet;
404    use std::fmt::Formatter;
405    use std::pin::Pin;
406    use std::sync::Mutex;
407    use std::task::{Context, Poll, Waker, ready};
408    use std::time::Duration;
409
410    use super::*;
411    use crate::coalesce_batches::CoalesceBatchesExec;
412    use crate::coalesce_partitions::CoalescePartitionsExec;
413    use crate::execution_plan::{Boundedness, EmissionType};
414    use crate::expressions::col;
415    use crate::metrics::{MetricValue, Timestamp};
416    use crate::repartition::RepartitionExec;
417    use crate::sorts::sort::SortExec;
418    use crate::stream::RecordBatchReceiverStream;
419    use crate::test::TestMemoryExec;
420    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
421    use crate::test::{self, assert_is_pending, make_partition};
422    use crate::{collect, common};
423
424    use arrow::array::{
425        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
426        TimestampNanosecondArray,
427    };
428    use arrow::compute::SortOptions;
429    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
430    use datafusion_common::test_util::batches_to_string;
431    use datafusion_common::{assert_batches_eq, exec_err};
432    use datafusion_common_runtime::SpawnedTask;
433    use datafusion_execution::RecordBatchStream;
434    use datafusion_execution::config::SessionConfig;
435    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
436    use datafusion_physical_expr::EquivalenceProperties;
437    use datafusion_physical_expr::expressions::Column;
438    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
439    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
440
441    use futures::{FutureExt, Stream, StreamExt};
442    use insta::assert_snapshot;
443    use tokio::time::timeout;
444
445    // The number in the function is highly related to the memory limit we are testing
446    // any change of the constant should be aware of
447    fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
448        let runtime = RuntimeEnvBuilder::new()
449            .with_memory_limit(20_000_000, 1.0)
450            .build_arc()?;
451        let config = SessionConfig::new();
452        let task_ctx = TaskContext::default()
453            .with_runtime(runtime)
454            .with_session_config(config);
455        Ok(Arc::new(task_ctx))
456    }
457    // The number in the function is highly related to the memory limit we are testing,
458    // any change of the constant should be aware of
459    fn generate_spm_for_round_robin_tie_breaker(
460        enable_round_robin_repartition: bool,
461    ) -> Result<Arc<SortPreservingMergeExec>> {
462        let target_batch_size = 12500;
463        let row_size = 12500;
464        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
465        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
466        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
467        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;
468
469        let rbs = (0..1024).map(|_| rb.clone()).collect::<Vec<_>>();
470
471        let schema = rb.schema();
472        let sort = [
473            PhysicalSortExpr {
474                expr: col("b", &schema)?,
475                options: Default::default(),
476            },
477            PhysicalSortExpr {
478                expr: col("c", &schema)?,
479                options: Default::default(),
480            },
481        ]
482        .into();
483
484        let repartition_exec = RepartitionExec::try_new(
485            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
486            Partitioning::RoundRobinBatch(2),
487        )?;
488        let coalesce_batches_exec =
489            CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
490        let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
491            .with_round_robin_repartition(enable_round_robin_repartition);
492        Ok(Arc::new(spm))
493    }
494
    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
    /// Any errors here could indicate unintended changes in tie breaker logic.
    ///
    /// Note: If you adjust constants in this test, ensure that memory usage differs
    /// based on whether the tie breaker is enabled or disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_success() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        // Round-robin tie breaking enabled: the merge consumes inputs evenly
        // and must complete within the 20 MB memory limit.
        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
        let _collected = collect(spm, task_ctx).await?;
        Ok(())
    }
507
    /// This test verifies that memory usage exceeds the limit when the tie breaker
    /// is disabled. Any errors here could indicate unintended changes in tie breaker logic.
    ///
    /// Note: If you adjust constants in this test, ensure that memory usage differs
    /// based on whether the tie breaker is enabled or disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        // Tie breaking disabled: inputs are drained unevenly, which is
        // expected to blow the memory limit and yield an error.
        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
        let _err = collect(spm, task_ctx).await.unwrap_err();
        Ok(())
    }
520
    /// Two partitions whose (b, c) keys strictly interleave; the merged output
    /// must alternate rows from each input.
    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        // Partition 1, sorted on (b, c) with default (ascending) options.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Partition 2: each key falls between consecutive keys of partition 1.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
568
    /// Two partitions whose `b` key ranges partially overlap (c..e appears in
    /// both); ties on `b` are resolved by the secondary key `c`.
    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
616
    /// Two partitions with disjoint `b` key ranges (a..e then f..j); the
    /// merged output is simply partition 1 followed by partition 2.
    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
664
    /// Three-way merge: partitions 2 and 3 share `b` keys (f..j) but differ in
    /// `c`, so ties on `b` are ordered by the secondary key `c`.
    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
729
730    async fn _test_merge(
731        partitions: &[Vec<RecordBatch>],
732        exp: &[&str],
733        context: Arc<TaskContext>,
734    ) {
735        let schema = partitions[0][0].schema();
736        let sort = [
737            PhysicalSortExpr {
738                expr: col("b", &schema).unwrap(),
739                options: Default::default(),
740            },
741            PhysicalSortExpr {
742                expr: col("c", &schema).unwrap(),
743                options: Default::default(),
744            },
745        ]
746        .into();
747        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
748        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
749
750        let collected = collect(merge, context).await.unwrap();
751        assert_batches_eq!(exp, collected.as_slice());
752    }
753
754    async fn sorted_merge(
755        input: Arc<dyn ExecutionPlan>,
756        sort: LexOrdering,
757        context: Arc<TaskContext>,
758    ) -> RecordBatch {
759        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
760        let mut result = collect(merge, context).await.unwrap();
761        assert_eq!(result.len(), 1);
762        result.remove(0)
763    }
764
765    async fn partition_sort(
766        input: Arc<dyn ExecutionPlan>,
767        sort: LexOrdering,
768        context: Arc<TaskContext>,
769    ) -> RecordBatch {
770        let sort_exec =
771            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
772        sorted_merge(sort_exec, sort, context).await
773    }
774
775    async fn basic_sort(
776        src: Arc<dyn ExecutionPlan>,
777        sort: LexOrdering,
778        context: Arc<TaskContext>,
779    ) -> RecordBatch {
780        let merge = Arc::new(CoalescePartitionsExec::new(src));
781        let sort_exec = Arc::new(SortExec::new(sort, merge));
782        let mut result = collect(sort_exec, context).await.unwrap();
783        assert_eq!(result.len(), 1);
784        result.remove(0)
785    }
786
787    #[tokio::test]
788    async fn test_partition_sort() -> Result<()> {
789        let task_ctx = Arc::new(TaskContext::default());
790        let partitions = 4;
791        let csv = test::scan_partitioned(partitions);
792        let schema = csv.schema();
793
794        let sort: LexOrdering = [PhysicalSortExpr {
795            expr: col("i", &schema)?,
796            options: SortOptions {
797                descending: true,
798                nulls_first: true,
799            },
800        }]
801        .into();
802
803        let basic =
804            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
805        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;
806
807        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
808            .unwrap()
809            .to_string();
810        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
811            .unwrap()
812            .to_string();
813
814        assert_eq!(
815            basic, partition,
816            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
817        );
818
819        Ok(())
820    }
821
822    // Split the provided record batch into multiple batch_size record batches
823    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
824        let batches = sorted.num_rows().div_ceil(batch_size);
825
826        // Split the sorted RecordBatch into multiple
827        (0..batches)
828            .map(|batch_idx| {
829                let columns = (0..sorted.num_columns())
830                    .map(|column_idx| {
831                        let length =
832                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);
833
834                        sorted
835                            .column(column_idx)
836                            .slice(batch_idx * batch_size, length)
837                    })
838                    .collect();
839
840                RecordBatch::try_new(sorted.schema(), columns).unwrap()
841            })
842            .collect()
843    }
844
845    async fn sorted_partitioned_input(
846        sort: LexOrdering,
847        sizes: &[usize],
848        context: Arc<TaskContext>,
849    ) -> Result<Arc<dyn ExecutionPlan>> {
850        let partitions = 4;
851        let csv = test::scan_partitioned(partitions);
852
853        let sorted = basic_sort(csv, sort, context).await;
854        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
855
856        TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
857    }
858
    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        // Merging pre-sorted partitions (with heterogeneous batch sizes) must
        // match a plain coalesce + global sort of the same data.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        // Single-column ascending sort on `i` with default options.
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Three sorted partitions chunked into batches of 10, 3, and 11 rows.
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        // Compare pretty-printed tables so mismatches show the full diff.
        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition =
            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }
887
    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        // Checks that the merge respects the configured output batch size.
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // batch size of 23
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await?;

        // 1200 rows at batch size 23 => ceil(1200 / 23) = 53 output batches.
        assert_eq!(merged.len(), 53);
        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        // Content must match the single-batch reference sort.
        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }
923
    #[tokio::test]
    async fn test_nulls() {
        // Verifies NULL placement during merge: column `b` sorts with nulls
        // first, column `c` with nulls last (see SortOptions below).
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Second partition overlaps the first on `b` values (e.g. "b", NULL),
        // exercising tie-breaking across partitions.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Note the two NULL `b` rows come first, and NULL `c` values sort
        // after non-NULL values within equal `b` groups.
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+-------------------------------+
        | a | b | c                             |
        +---+---+-------------------------------+
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 2 | a |                               |
        | 7 | b | 1970-01-01T00:00:00.000000006 |
        | 2 | b |                               |
        | 9 | d |                               |
        | 3 | e | 1970-01-01T00:00:00.000000004 |
        | 3 | g | 1970-01-01T00:00:00.000000005 |
        | 4 | h |                               |
        | 5 | i | 1970-01-01T00:00:00.000000004 |
        +---+---+-------------------------------+
        ");
    }
1003
    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        // A single pre-sorted partition with `fetch = Some(2)` must emit only
        // the first two rows.
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge =
            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        +---+---+
        ");
    }
1036
    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        // Without a fetch limit, a single input partition passes through in
        // its original (already sorted, as far as the merge is concerned) order.
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        | 7 | c |
        | 9 | d |
        | 3 | e |
        +---+---+
        ");
    }
1071
    #[tokio::test]
    async fn test_async() -> Result<()> {
        // Drives the low-level StreamingMergeBuilder directly with inputs that
        // intermittently return Pending, and compares against a plain sort.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into();

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        // Wrap each partition in a channel-backed stream (capacity 1) that
        // sleeps between batches, so the merge repeatedly observes Pending.
        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    // This causes the MergeStream to wait for more input
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(&sort)
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        // Reference result computed with coalesce + global sort.
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }
1143
    #[tokio::test]
    async fn test_merge_metrics() {
        // Verifies the merge reports baseline metrics: output rows, elapsed
        // compute time, and start/end timestamps.
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +----+---+
        | a  | b |
        +----+---+
        | 1  | a |
        | 10 | b |
        | 2  | c |
        | 20 | d |
        +----+---+
        ");

        // Now, validate metrics
        let metrics = merge.metrics().unwrap();

        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        // Both a start and an end timestamp must be recorded and non-zero.
        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }
1202
1203    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
1204        ts.value().unwrap().timestamp_nanos_opt().unwrap()
1205    }
1206
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        // Dropping a pending merge future must release all references to the
        // input plan, i.e. cancellation must not leak the child streams.
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        // BlockingExec never yields data, keeping the merge future pending.
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
            [PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            blocking_exec,
        ));

        let fut = collect(sort_preserving_merge_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }
1233
    #[tokio::test]
    async fn test_stable_sort() {
        // Verifies the merge is stable: rows that compare equal on the sort
        // key keep the order of the partitions that produced them.
        let task_ctx = Arc::new(TaskContext::default());

        // Create record batches like:
        // batch_number |value
        // -------------+------
        //    1         | A
        //    1         | B
        //
        // Ensure that the output is in the same order the batches were fed
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        // Sort only on `value`; `batch_number` is deliberately NOT a sort key.
        let sort = [PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();

        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Expect the data to be sorted first by "batch_number" (because
        // that was the order it was fed in, even though only "value"
        // is in the sort key)
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +--------------+-------+
        | batch_number | value |
        +--------------+-------+
        | 0            | A     |
        | 1            | A     |
        | 2            | A     |
        | 3            | A     |
        | 4            | A     |
        | 5            | A     |
        | 6            | A     |
        | 7            | A     |
        | 8            | A     |
        | 9            | A     |
        | 0            | B     |
        | 1            | B     |
        | 2            | B     |
        | 3            | B     |
        | 4            | B     |
        | 5            | B     |
        | 6            | B     |
        | 7            | B     |
        | 8            | B     |
        | 9            | B     |
        +--------------+-------+
        ");
    }
1310
    /// Shared state for the congestion tracker: which partitions have not yet
    /// been polled, plus the wakers of streams parked waiting for that.
    #[derive(Debug)]
    struct CongestionState {
        // Wakers of streams that returned `Poll::Pending` from `check_congested`.
        wakers: Vec<Waker>,
        // Partitions that have not been polled at least once yet.
        unpolled_partitions: HashSet<usize>,
    }
1316
    /// Coordinates polling across `CongestedStream` partitions; see
    /// `check_congested` for the blocking/wake-up protocol.
    #[derive(Debug)]
    struct Congestion {
        congestion_state: Mutex<CongestionState>,
    }
1321
1322    impl Congestion {
1323        fn new(partition_count: usize) -> Self {
1324            Congestion {
1325                congestion_state: Mutex::new(CongestionState {
1326                    wakers: vec![],
1327                    unpolled_partitions: (0usize..partition_count).collect(),
1328                }),
1329            }
1330        }
1331
1332        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
1333            let mut state = self.congestion_state.lock().unwrap();
1334
1335            state.unpolled_partitions.remove(&partition);
1336
1337            if state.unpolled_partitions.is_empty() {
1338                state.wakers.iter().for_each(|w| w.wake_by_ref());
1339                state.wakers.clear();
1340                Poll::Ready(())
1341            } else {
1342                state.wakers.push(cx.waker().clone());
1343                Poll::Pending
1344            }
1345        }
1346    }
1347
    /// Exec used to exercise merge "congestion": partition 0 is exhausted from
    /// the start (and panics if it is polled more than once), while the other
    /// partitions return pending until every partition has been polled at
    /// least once.
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        // Shared tracker coordinating which partitions have been polled.
        congestion: Arc<Congestion>,
    }
1356
    impl CongestedExec {
        /// Builds plan properties for the test exec: ordered ascending on every
        /// column, hash-partitioned into 3 partitions, incremental emission,
        /// and unbounded (stream-like) output.
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            // One column expression per schema field, in field order.
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_ordering(
                columns
                    .iter()
                    .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
            );
            PlanProperties::new(
                eq_properties,
                // 3 partitions: the congestion scenario needs at least three.
                Partitioning::Hash(columns, 3),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            )
        }
    }
1381
    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &PlanProperties {
            &self.cache
        }
        // Leaf node: no children.
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        /// Each partition gets a `CongestedStream` sharing the same
        /// `Congestion` tracker, so the partitions coordinate their polling.
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion: Arc::clone(&self.congestion),
                partition,
            }))
        }
    }
1414
1415    impl DisplayAs for CongestedExec {
1416        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1417            match t {
1418                DisplayFormatType::Default | DisplayFormatType::Verbose => {
1419                    write!(f, "CongestedExec",).unwrap()
1420                }
1421                DisplayFormatType::TreeRender => {
1422                    // TODO: collect info
1423                    write!(f, "").unwrap()
1424                }
1425            }
1426            Ok(())
1427        }
1428    }
1429
    /// Stream for one `CongestedExec` partition. Partition 0 is exhausted from
    /// the start, and if it is polled more than once, it panics; the other
    /// partitions return pending until every partition has been polled at
    /// least once.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        // Whether partition 0 has already reported end-of-stream once.
        none_polled_once: bool,
        // Shared tracker coordinating polling across partitions.
        congestion: Arc<Congestion>,
        partition: usize,
    }
1439
    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    // Partition 0 reports itself to the congestion tracker but
                    // deliberately ignores the result so it can end right away.
                    let _ = self.congestion.check_congested(self.partition, cx);
                    if self.none_polled_once {
                        // A well-behaved merge must never re-poll a stream
                        // that has already returned `Ready(None)`.
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                _ => {
                    // Remaining partitions stay `Pending` until every partition
                    // has been polled at least once, then finish immediately.
                    ready!(self.congestion.check_congested(self.partition, cx));
                    Poll::Ready(None)
                }
            }
        }
    }
1463
    impl RecordBatchStream for CongestedStream {
        /// Returns the schema this stream was created with.
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }
1469
    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        // Regression test: SortPreservingMerge must eventually poll every
        // input partition even when some partitions stay pending until all
        // partitions have been polled; a faulty merge deadlocks here and the
        // timeout below fires.
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
        let &partition_count = match properties.output_partitioning() {
            Partitioning::RoundRobinBatch(partitions) => partitions,
            Partitioning::Hash(_, partitions) => partitions,
            Partitioning::UnknownPartitioning(partitions) => partitions,
        };
        let source = CongestedExec {
            schema: schema.clone(),
            cache: properties,
            congestion: Arc::new(Congestion::new(partition_count)),
        };
        let spm = SortPreservingMergeExec::new(
            [PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))]
            .into(),
            Arc::new(source),
        );
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        // 3 seconds is ample; a deadlocked merge would hang forever otherwise.
        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
            Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
        }
    }
1502}