// datafusion_physical_plan/sorts/sort_preserving_merge.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.
19
20use std::any::Any;
21use std::sync::Arc;
22
23use crate::common::spawn_buffered;
24use crate::limit::LimitStream;
25use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
26use crate::projection::{ProjectionExec, make_with_child, update_ordering};
27use crate::sorts::streaming_merge::StreamingMergeBuilder;
28use crate::{
29    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
30    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
31    check_if_same_properties,
32};
33
34use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryConsumer;
37use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
38
39use crate::execution_plan::{EvaluationType, SchedulingType};
40use log::{debug, trace};
41
/// Sort preserving merge execution plan
///
/// # Overview
///
/// This operator implements a K-way merge. It is used to merge multiple sorted
/// streams into a single sorted stream and is highly optimized.
///
/// ## Inputs:
///
/// 1. A list of sort expressions
/// 2. An input plan, where each partition is sorted with respect to
///    these sort expressions.
///
/// ## Output:
///
/// 1. A single partition that is also sorted with respect to the expressions
///
/// ## Diagram
///
/// ```text
/// ┌─────────────────────────┐
/// │ ┌───┬───┬───┬───┐       │
/// │ │ A │ B │ C │ D │ ...   │──┐
/// │ └───┴───┴───┴───┘       │  │
/// └─────────────────────────┘  │  ┌───────────────────┐    ┌───────────────────────────────┐
///   Stream 1                   │  │                   │    │ ┌───┬───╦═══╦───┬───╦═══╗     │
///                              ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │
///                              │  │                   │    │ └───┴─▲─╩═══╩───┴───╩═══╝     │
/// ┌─────────────────────────┐  │  └───────────────────┘    └─┬─────┴───────────────────────┘
/// │ ╔═══╦═══╗               │  │
/// │ ║ B ║ E ║     ...       │──┘                             │
/// │ ╚═══╩═══╝               │              Stable sort if `enable_round_robin_repartition=false`:
/// └─────────────────────────┘              the merged stream places equal rows from stream 1
///   Stream 2                               before equal rows from stream 2 (input order)
///
///
///  Input Partitions                                          Output Partition
///    (sorted)                                                  (sorted)
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// the output and inputs are not polled again.
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan with sorted partitions
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions (each input partition must already be sorted by these)
    expr: LexOrdering,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch. Stops producing rows after this fetch
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: Arc<PlanProperties>,
    /// Use round-robin selection of tied winners of loser tree
    ///
    /// See [`Self::with_round_robin_repartition`] for more information.
    enable_round_robin_repartition: bool,
}
103
104impl SortPreservingMergeExec {
105    /// Create a new sort execution plan
106    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
107        let cache = Self::compute_properties(&input, expr.clone());
108        Self {
109            input,
110            expr,
111            metrics: ExecutionPlanMetricsSet::new(),
112            fetch: None,
113            cache: Arc::new(cache),
114            enable_round_robin_repartition: true,
115        }
116    }
117
118    /// Sets the number of rows to fetch
119    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
120        self.fetch = fetch;
121        self
122    }
123
124    /// Sets the selection strategy of tied winners of the loser tree algorithm
125    ///
126    /// If true (the default) equal output rows are placed in the merged stream
127    /// in round robin fashion. This approach consumes input streams at more
128    /// even rates when there are many rows with the same sort key.
129    ///
130    /// If false, equal output rows are always placed in the merged stream in
131    /// the order of the inputs, resulting in potentially slower execution but a
132    /// stable output order.
133    pub fn with_round_robin_repartition(
134        mut self,
135        enable_round_robin_repartition: bool,
136    ) -> Self {
137        self.enable_round_robin_repartition = enable_round_robin_repartition;
138        self
139    }
140
141    /// Input schema
142    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
143        &self.input
144    }
145
146    /// Sort expressions
147    pub fn expr(&self) -> &LexOrdering {
148        &self.expr
149    }
150
151    /// Fetch
152    pub fn fetch(&self) -> Option<usize> {
153        self.fetch
154    }
155
156    /// Creates the cache object that stores the plan properties
157    /// such as schema, equivalence properties, ordering, partitioning, etc.
158    fn compute_properties(
159        input: &Arc<dyn ExecutionPlan>,
160        ordering: LexOrdering,
161    ) -> PlanProperties {
162        let input_partitions = input.output_partitioning().partition_count();
163        let (drive, scheduling) = if input_partitions > 1 {
164            (EvaluationType::Eager, SchedulingType::Cooperative)
165        } else {
166            (
167                input.properties().evaluation_type,
168                input.properties().scheduling_type,
169            )
170        };
171
172        let mut eq_properties = input.equivalence_properties().clone();
173        eq_properties.clear_per_partition_constants();
174        eq_properties.add_ordering(ordering);
175        PlanProperties::new(
176            eq_properties,                        // Equivalence Properties
177            Partitioning::UnknownPartitioning(1), // Output Partitioning
178            input.pipeline_behavior(),            // Pipeline Behavior
179            input.boundedness(),                  // Boundedness
180        )
181        .with_evaluation_type(drive)
182        .with_scheduling_type(scheduling)
183    }
184
185    fn with_new_children_and_same_properties(
186        &self,
187        mut children: Vec<Arc<dyn ExecutionPlan>>,
188    ) -> Self {
189        Self {
190            input: children.swap_remove(0),
191            metrics: ExecutionPlanMetricsSet::new(),
192            ..Self::clone(self)
193        }
194    }
195}
196
197impl DisplayAs for SortPreservingMergeExec {
198    fn fmt_as(
199        &self,
200        t: DisplayFormatType,
201        f: &mut std::fmt::Formatter,
202    ) -> std::fmt::Result {
203        match t {
204            DisplayFormatType::Default | DisplayFormatType::Verbose => {
205                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
206                if let Some(fetch) = self.fetch {
207                    write!(f, ", fetch={fetch}")?;
208                };
209
210                Ok(())
211            }
212            DisplayFormatType::TreeRender => {
213                if let Some(fetch) = self.fetch {
214                    writeln!(f, "limit={fetch}")?;
215                };
216
217                for (i, e) in self.expr().iter().enumerate() {
218                    e.fmt_sql(f)?;
219                    if i != self.expr().len() - 1 {
220                        write!(f, ", ")?;
221                    }
222                }
223
224                Ok(())
225            }
226        }
227    }
228}
229
230impl ExecutionPlan for SortPreservingMergeExec {
231    fn name(&self) -> &'static str {
232        "SortPreservingMergeExec"
233    }
234
235    /// Return a reference to Any that can be used for downcasting
236    fn as_any(&self) -> &dyn Any {
237        self
238    }
239
240    fn properties(&self) -> &Arc<PlanProperties> {
241        &self.cache
242    }
243
244    fn fetch(&self) -> Option<usize> {
245        self.fetch
246    }
247
248    /// Sets the number of rows to fetch
249    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
250        Some(Arc::new(Self {
251            input: Arc::clone(&self.input),
252            expr: self.expr.clone(),
253            metrics: self.metrics.clone(),
254            fetch: limit,
255            cache: Arc::clone(&self.cache),
256            enable_round_robin_repartition: true,
257        }))
258    }
259
260    fn with_preserve_order(
261        &self,
262        preserve_order: bool,
263    ) -> Option<Arc<dyn ExecutionPlan>> {
264        self.input
265            .with_preserve_order(preserve_order)
266            .and_then(|new_input| {
267                Arc::new(self.clone())
268                    .with_new_children(vec![new_input])
269                    .ok()
270            })
271    }
272
273    fn required_input_distribution(&self) -> Vec<Distribution> {
274        vec![Distribution::UnspecifiedDistribution]
275    }
276
277    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
278        vec![false]
279    }
280
281    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
282        vec![Some(OrderingRequirements::from(self.expr.clone()))]
283    }
284
285    fn maintains_input_order(&self) -> Vec<bool> {
286        vec![true]
287    }
288
289    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
290        vec![&self.input]
291    }
292
293    fn with_new_children(
294        self: Arc<Self>,
295        mut children: Vec<Arc<dyn ExecutionPlan>>,
296    ) -> Result<Arc<dyn ExecutionPlan>> {
297        check_if_same_properties!(self, children);
298        Ok(Arc::new(
299            SortPreservingMergeExec::new(self.expr.clone(), children.swap_remove(0))
300                .with_fetch(self.fetch),
301        ))
302    }
303
304    fn execute(
305        &self,
306        partition: usize,
307        context: Arc<TaskContext>,
308    ) -> Result<SendableRecordBatchStream> {
309        trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
310        assert_eq_or_internal_err!(
311            partition,
312            0,
313            "SortPreservingMergeExec invalid partition {partition}"
314        );
315
316        let input_partitions = self.input.output_partitioning().partition_count();
317        trace!(
318            "Number of input partitions of  SortPreservingMergeExec::execute: {input_partitions}"
319        );
320        let schema = self.schema();
321
322        let reservation =
323            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
324                .register(&context.runtime_env().memory_pool);
325
326        match input_partitions {
327            0 => internal_err!(
328                "SortPreservingMergeExec requires at least one input partition"
329            ),
330            1 => match self.fetch {
331                Some(fetch) => {
332                    let stream = self.input.execute(0, context)?;
333                    debug!(
334                        "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
335                    );
336                    Ok(Box::pin(LimitStream::new(
337                        stream,
338                        0,
339                        Some(fetch),
340                        BaselineMetrics::new(&self.metrics, partition),
341                    )))
342                }
343                None => {
344                    let stream = self.input.execute(0, context);
345                    debug!(
346                        "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
347                    );
348                    stream
349                }
350            },
351            _ => {
352                let receivers = (0..input_partitions)
353                    .map(|partition| {
354                        let stream =
355                            self.input.execute(partition, Arc::clone(&context))?;
356                        Ok(spawn_buffered(stream, 1))
357                    })
358                    .collect::<Result<_>>()?;
359
360                debug!(
361                    "Done setting up sender-receiver for SortPreservingMergeExec::execute"
362                );
363
364                let result = StreamingMergeBuilder::new()
365                    .with_streams(receivers)
366                    .with_schema(schema)
367                    .with_expressions(&self.expr)
368                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
369                    .with_batch_size(context.session_config().batch_size())
370                    .with_fetch(self.fetch)
371                    .with_reservation(reservation)
372                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
373                    .build()?;
374
375                debug!(
376                    "Got stream result from SortPreservingMergeStream::new_from_receivers"
377                );
378
379                Ok(result)
380            }
381        }
382    }
383
384    fn metrics(&self) -> Option<MetricsSet> {
385        Some(self.metrics.clone_inner())
386    }
387
388    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
389        self.input.partition_statistics(None)
390    }
391
392    fn supports_limit_pushdown(&self) -> bool {
393        true
394    }
395
396    /// Tries to swap the projection with its input [`SortPreservingMergeExec`].
397    /// If this is possible, it returns the new [`SortPreservingMergeExec`] whose
398    /// child is a projection. Otherwise, it returns None.
399    fn try_swapping_with_projection(
400        &self,
401        projection: &ProjectionExec,
402    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
403        // If the projection does not narrow the schema, we should not try to push it down.
404        if projection.expr().len() >= projection.input().schema().fields().len() {
405            return Ok(None);
406        }
407
408        let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
409        else {
410            return Ok(None);
411        };
412
413        Ok(Some(Arc::new(
414            SortPreservingMergeExec::new(
415                updated_exprs,
416                make_with_child(projection, self.input())?,
417            )
418            .with_fetch(self.fetch()),
419        )))
420    }
421}
422
423#[cfg(test)]
424mod tests {
425    use std::collections::HashSet;
426    use std::fmt::Formatter;
427    use std::pin::Pin;
428    use std::sync::Mutex;
429    use std::task::{Context, Poll, Waker, ready};
430    use std::time::Duration;
431
432    use super::*;
433    use crate::coalesce_partitions::CoalescePartitionsExec;
434    use crate::execution_plan::{Boundedness, EmissionType};
435    use crate::expressions::col;
436    use crate::metrics::{MetricValue, Timestamp};
437    use crate::repartition::RepartitionExec;
438    use crate::sorts::sort::SortExec;
439    use crate::stream::RecordBatchReceiverStream;
440    use crate::test::TestMemoryExec;
441    use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
442    use crate::test::{self, assert_is_pending, make_partition};
443    use crate::{collect, common};
444
445    use arrow::array::{
446        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
447        TimestampNanosecondArray,
448    };
449    use arrow::compute::SortOptions;
450    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
451    use datafusion_common::test_util::batches_to_string;
452    use datafusion_common::{assert_batches_eq, exec_err};
453    use datafusion_common_runtime::SpawnedTask;
454    use datafusion_execution::RecordBatchStream;
455    use datafusion_execution::config::SessionConfig;
456    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
457    use datafusion_physical_expr::EquivalenceProperties;
458    use datafusion_physical_expr::expressions::Column;
459    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
460    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
461
462    use futures::{FutureExt, Stream, StreamExt};
463    use insta::assert_snapshot;
464    use tokio::time::timeout;
465
466    // The number in the function is highly related to the memory limit we are testing
467    // any change of the constant should be aware of
468    fn generate_task_ctx_for_round_robin_tie_breaker(
469        target_batch_size: usize,
470    ) -> Result<Arc<TaskContext>> {
471        let runtime = RuntimeEnvBuilder::new()
472            .with_memory_limit(20_000_000, 1.0)
473            .build_arc()?;
474        let mut config = SessionConfig::new();
475        config.options_mut().execution.batch_size = target_batch_size;
476        let task_ctx = TaskContext::default()
477            .with_runtime(runtime)
478            .with_session_config(config);
479        Ok(Arc::new(task_ctx))
480    }
481    // The number in the function is highly related to the memory limit we are testing,
482    // any change of the constant should be aware of
483    fn generate_spm_for_round_robin_tie_breaker(
484        enable_round_robin_repartition: bool,
485    ) -> Result<Arc<SortPreservingMergeExec>> {
486        let row_size = 12500;
487        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
488        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
489        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
490        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;
491        let schema = rb.schema();
492
493        let rbs = std::iter::repeat_n(rb, 1024).collect::<Vec<_>>();
494        let sort = [
495            PhysicalSortExpr {
496                expr: col("b", &schema)?,
497                options: Default::default(),
498            },
499            PhysicalSortExpr {
500                expr: col("c", &schema)?,
501                options: Default::default(),
502            },
503        ]
504        .into();
505
506        let repartition_exec = RepartitionExec::try_new(
507            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
508            Partitioning::RoundRobinBatch(2),
509        )?;
510        let spm = SortPreservingMergeExec::new(sort, Arc::new(repartition_exec))
511            .with_round_robin_repartition(enable_round_robin_repartition);
512        Ok(Arc::new(spm))
513    }
514
515    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
516    /// Any errors here could indicate unintended changes in tie breaker logic.
517    ///
518    /// Note: If you adjust constants in this test, ensure that memory usage differs
519    /// based on whether the tie breaker is enabled or disabled.
520    #[tokio::test(flavor = "multi_thread")]
521    async fn test_round_robin_tie_breaker_success() -> Result<()> {
522        let target_batch_size = 12500;
523        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(target_batch_size)?;
524        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
525        let _collected = collect(spm, task_ctx).await?;
526        Ok(())
527    }
528
529    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
530    /// Any errors here could indicate unintended changes in tie breaker logic.
531    ///
532    /// Note: If you adjust constants in this test, ensure that memory usage differs
533    /// based on whether the tie breaker is enabled or disabled.
534    #[tokio::test(flavor = "multi_thread")]
535    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
536        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(8192)?;
537        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
538        let _err = collect(spm, task_ctx).await.unwrap_err();
539        Ok(())
540    }
541
    // Two partitions whose keys on column `b` strictly interleave
    // (a,c,e,g,j vs b,d,f,h,j): the merge must alternate rows from both
    // inputs. The final two rows share key `j`, exercising tie handling.
    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        // Partition 1
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Partition 2
        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
589
    // Partitions that overlap on keys c, d, e of column `b`; rows with equal
    // keys are further ordered by the secondary sort column `c`.
    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        // Partition 1: keys a..e
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Partition 2: keys c..g (overlaps on c, d, e)
        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
637
    // Disjoint key ranges (a..e then f..j): the merged output is simply
    // partition 1 followed by partition 2.
    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        // Partition 1: keys a..e
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Partition 2: keys f..j (entirely after partition 1)
        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
685
    // Three-way merge with overlapping keys f..j across partitions; ties on
    // column `b` are resolved by the secondary sort column `c`.
    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        // Partition 1: keys a, b, c, d, f
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Partition 2: keys e, g, h, i, j with larger timestamps
        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Partition 3: keys f, g, h, i, j with smaller timestamps
        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
750
751    async fn _test_merge(
752        partitions: &[Vec<RecordBatch>],
753        exp: &[&str],
754        context: Arc<TaskContext>,
755    ) {
756        let schema = partitions[0][0].schema();
757        let sort = [
758            PhysicalSortExpr {
759                expr: col("b", &schema).unwrap(),
760                options: Default::default(),
761            },
762            PhysicalSortExpr {
763                expr: col("c", &schema).unwrap(),
764                options: Default::default(),
765            },
766        ]
767        .into();
768        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
769        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
770
771        let collected = collect(merge, context).await.unwrap();
772        assert_batches_eq!(exp, collected.as_slice());
773    }
774
775    async fn sorted_merge(
776        input: Arc<dyn ExecutionPlan>,
777        sort: LexOrdering,
778        context: Arc<TaskContext>,
779    ) -> RecordBatch {
780        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
781        let mut result = collect(merge, context).await.unwrap();
782        assert_eq!(result.len(), 1);
783        result.remove(0)
784    }
785
786    async fn partition_sort(
787        input: Arc<dyn ExecutionPlan>,
788        sort: LexOrdering,
789        context: Arc<TaskContext>,
790    ) -> RecordBatch {
791        let sort_exec =
792            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
793        sorted_merge(sort_exec, sort, context).await
794    }
795
796    async fn basic_sort(
797        src: Arc<dyn ExecutionPlan>,
798        sort: LexOrdering,
799        context: Arc<TaskContext>,
800    ) -> RecordBatch {
801        let merge = Arc::new(CoalescePartitionsExec::new(src));
802        let sort_exec = Arc::new(SortExec::new(sort, merge));
803        let mut result = collect(sort_exec, context).await.unwrap();
804        assert_eq!(result.len(), 1);
805        result.remove(0)
806    }
807
808    #[tokio::test]
809    async fn test_partition_sort() -> Result<()> {
810        let task_ctx = Arc::new(TaskContext::default());
811        let partitions = 4;
812        let csv = test::scan_partitioned(partitions);
813        let schema = csv.schema();
814
815        let sort: LexOrdering = [PhysicalSortExpr {
816            expr: col("i", &schema)?,
817            options: SortOptions {
818                descending: true,
819                nulls_first: true,
820            },
821        }]
822        .into();
823
824        let basic =
825            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
826        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;
827
828        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
829            .unwrap()
830            .to_string();
831        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
832            .unwrap()
833            .to_string();
834
835        assert_eq!(
836            basic, partition,
837            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
838        );
839
840        Ok(())
841    }
842
843    // Split the provided record batch into multiple batch_size record batches
844    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
845        let batches = sorted.num_rows().div_ceil(batch_size);
846
847        // Split the sorted RecordBatch into multiple
848        (0..batches)
849            .map(|batch_idx| {
850                let columns = (0..sorted.num_columns())
851                    .map(|column_idx| {
852                        let length =
853                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);
854
855                        sorted
856                            .column(column_idx)
857                            .slice(batch_idx * batch_size, length)
858                    })
859                    .collect();
860
861                RecordBatch::try_new(sorted.schema(), columns).unwrap()
862            })
863            .collect()
864    }
865
866    async fn sorted_partitioned_input(
867        sort: LexOrdering,
868        sizes: &[usize],
869        context: Arc<TaskContext>,
870    ) -> Result<Arc<dyn ExecutionPlan>> {
871        let partitions = 4;
872        let csv = test::scan_partitioned(partitions);
873
874        let sorted = basic_sort(csv, sort, context).await;
875        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
876
877        TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
878    }
879
    /// Merging pre-sorted partitions (with uneven batch sizes) must agree
    /// with a plain full sort of the same data.
    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        // Single ascending sort key on "i" (default options).
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Three partitions of identical sorted data, split into batches of
        // 10, 3, and 11 rows respectively.
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        // Reference result: coalesce + full sort.
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        // Result under test: sort-preserving merge of the sorted partitions.
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        // Neither strategy may drop or duplicate rows...
        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition =
            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();

        // ...and both must produce identical contents and order.
        assert_eq!(basic, partition);

        Ok(())
    }
908
    /// The merge must re-batch its output according to the session batch
    /// size without losing or duplicating rows.
    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        // Single ascending sort key on "i" (default options).
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // batch size of 23
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await?;

        // 1200 rows emitted in batches of 23 => ceil(1200 / 23) = 53 batches.
        assert_eq!(merged.len(), 53);
        assert_eq!(basic.num_rows(), 1200);
        // Re-batching must preserve the total row count.
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();

        // Contents and order must match the reference full sort.
        assert_eq!(basic, partition);

        Ok(())
    }
944
    /// Merging two partitions with nulls in both sort keys: "b" is ascending
    /// nulls-first, "c" is ascending nulls-last.
    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        // First partition: nulls in both the "b" and "c" columns.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        // Second partition: overlapping "b" values ("b") and more nulls.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        // Sort on "b" (ascending, nulls first), then "c" (ascending, nulls last).
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Null "b" rows sort first; on equal "b" ("b"), the non-null "c"
        // sorts before the null one because "c" is nulls-last.
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+-------------------------------+
        | a | b | c                             |
        +---+---+-------------------------------+
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 2 | a |                               |
        | 7 | b | 1970-01-01T00:00:00.000000006 |
        | 2 | b |                               |
        | 9 | d |                               |
        | 3 | e | 1970-01-01T00:00:00.000000004 |
        | 3 | g | 1970-01-01T00:00:00.000000005 |
        | 4 | h |                               |
        | 5 | i | 1970-01-01T00:00:00.000000004 |
        +---+---+-------------------------------+
        ");
    }
1024
1025    #[tokio::test]
1026    async fn test_sort_merge_single_partition_with_fetch() {
1027        let task_ctx = Arc::new(TaskContext::default());
1028        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1029        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1030        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1031        let schema = batch.schema();
1032
1033        let sort = [PhysicalSortExpr {
1034            expr: col("b", &schema).unwrap(),
1035            options: SortOptions {
1036                descending: false,
1037                nulls_first: true,
1038            },
1039        }]
1040        .into();
1041        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1042        let merge =
1043            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));
1044
1045        let collected = collect(merge, task_ctx).await.unwrap();
1046        assert_eq!(collected.len(), 1);
1047
1048        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1049        +---+---+
1050        | a | b |
1051        +---+---+
1052        | 1 | a |
1053        | 2 | b |
1054        +---+---+
1055        ");
1056    }
1057
1058    #[tokio::test]
1059    async fn test_sort_merge_single_partition_without_fetch() {
1060        let task_ctx = Arc::new(TaskContext::default());
1061        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1062        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1063        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1064        let schema = batch.schema();
1065
1066        let sort = [PhysicalSortExpr {
1067            expr: col("b", &schema).unwrap(),
1068            options: SortOptions {
1069                descending: false,
1070                nulls_first: true,
1071            },
1072        }]
1073        .into();
1074        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1075        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1076
1077        let collected = collect(merge, task_ctx).await.unwrap();
1078        assert_eq!(collected.len(), 1);
1079
1080        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1081        +---+---+
1082        | a | b |
1083        +---+---+
1084        | 1 | a |
1085        | 2 | b |
1086        | 7 | c |
1087        | 9 | d |
1088        | 3 | e |
1089        +---+---+
1090        ");
1091    }
1092
    /// Drives `StreamingMergeBuilder` directly over channel-backed streams
    /// that intermittently return `Pending`, and checks the result against a
    /// plain full sort.
    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        // Single ascending sort key on "i".
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into();

        // Three partitions of identical sorted data, split into batches of
        // 5, 7, and 3 rows respectively.
        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        // Wrap each partition in a channel-backed stream with a buffer of one
        // batch, sleeping between sends so the merge observes slow inputs.
        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    // This causes the MergeStream to wait for more input
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        // Build and drive the streaming merge directly (no fetch limit).
        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(&sort)
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        // The slow, channel-fed merge must agree with the reference full sort.
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }
1164
1165    #[tokio::test]
1166    async fn test_merge_metrics() {
1167        let task_ctx = Arc::new(TaskContext::default());
1168        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1169        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
1170        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1171
1172        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
1173        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
1174        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1175
1176        let schema = b1.schema();
1177        let sort = [PhysicalSortExpr {
1178            expr: col("b", &schema).unwrap(),
1179            options: Default::default(),
1180        }]
1181        .into();
1182        let exec =
1183            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
1184        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1185
1186        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
1187            .await
1188            .unwrap();
1189        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1190        +----+---+
1191        | a  | b |
1192        +----+---+
1193        | 1  | a |
1194        | 10 | b |
1195        | 2  | c |
1196        | 20 | d |
1197        +----+---+
1198        ");
1199
1200        // Now, validate metrics
1201        let metrics = merge.metrics().unwrap();
1202
1203        assert_eq!(metrics.output_rows().unwrap(), 4);
1204        assert!(metrics.elapsed_compute().unwrap() > 0);
1205
1206        let mut saw_start = false;
1207        let mut saw_end = false;
1208        metrics.iter().for_each(|m| match m.value() {
1209            MetricValue::StartTimestamp(ts) => {
1210                saw_start = true;
1211                assert!(nanos_from_timestamp(ts) > 0);
1212            }
1213            MetricValue::EndTimestamp(ts) => {
1214                saw_end = true;
1215                assert!(nanos_from_timestamp(ts) > 0);
1216            }
1217            _ => {}
1218        });
1219
1220        assert!(saw_start);
1221        assert!(saw_end);
1222    }
1223
    /// Extracts the raw nanosecond value from a timestamp metric.
    /// Panics if the metric was never set or its value does not fit in
    /// nanoseconds as an `i64` — both indicate a test bug.
    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
        ts.value().unwrap().timestamp_nanos_opt().unwrap()
    }
1227
1228    #[tokio::test]
1229    async fn test_drop_cancel() -> Result<()> {
1230        let task_ctx = Arc::new(TaskContext::default());
1231        let schema =
1232            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1233
1234        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1235        let refs = blocking_exec.refs();
1236        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
1237            [PhysicalSortExpr {
1238                expr: col("a", &schema)?,
1239                options: SortOptions::default(),
1240            }]
1241            .into(),
1242            blocking_exec,
1243        ));
1244
1245        let fut = collect(sort_preserving_merge_exec, task_ctx);
1246        let mut fut = fut.boxed();
1247
1248        assert_is_pending(&mut fut);
1249        drop(fut);
1250        assert_strong_count_converges_to_zero(refs).await;
1251
1252        Ok(())
1253    }
1254
    /// The merge must be stable: rows comparing equal on the sort key keep
    /// the order of the partitions they came from.
    #[tokio::test]
    async fn test_stable_sort() {
        let task_ctx = Arc::new(TaskContext::default());

        // Create record batches like:
        // batch_number |value
        // -------------+------
        //    1         | A
        //    1         | B
        //
        // Ensure that the output is in the same order the batches were fed
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        // Sort only on "value"; "batch_number" ties are broken by input order.
        let sort = [PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();

        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Expect the data to be sorted first by "batch_number" (because
        // that was the order it was fed in, even though only "value"
        // is in the sort key)
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +--------------+-------+
        | batch_number | value |
        +--------------+-------+
        | 0            | A     |
        | 1            | A     |
        | 2            | A     |
        | 3            | A     |
        | 4            | A     |
        | 5            | A     |
        | 6            | A     |
        | 7            | A     |
        | 8            | A     |
        | 9            | A     |
        | 0            | B     |
        | 1            | B     |
        | 2            | B     |
        | 3            | B     |
        | 4            | B     |
        | 5            | B     |
        | 6            | B     |
        | 7            | B     |
        | 8            | B     |
        | 9            | B     |
        +--------------+-------+
        ");
    }
1331
    /// Mutable state shared by all [`CongestedStream`]s: which partitions
    /// still await their first poll, and which tasks to wake when done.
    #[derive(Debug)]
    struct CongestionState {
        // Wakers registered by partitions that returned `Pending`; all are
        // woken once every partition has been polled at least once.
        wakers: Vec<Waker>,
        // Partition indices that have not yet been polled.
        unpolled_partitions: HashSet<usize>,
    }
1337
    /// Coordinator tracking which partitions of [`CongestedExec`] have been
    /// polled; streams complete only after every partition is polled once.
    #[derive(Debug)]
    struct Congestion {
        congestion_state: Mutex<CongestionState>,
    }
1342
    impl Congestion {
        /// Creates state in which all `partition_count` partitions are still
        /// considered unpolled.
        fn new(partition_count: usize) -> Self {
            Congestion {
                congestion_state: Mutex::new(CongestionState {
                    wakers: vec![],
                    unpolled_partitions: (0usize..partition_count).collect(),
                }),
            }
        }

        /// Records that `partition` has been polled. Returns `Ready` (waking
        /// every registered waker) once all partitions have been polled at
        /// least once; otherwise registers the caller's waker and returns
        /// `Pending`.
        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
            let mut state = self.congestion_state.lock().unwrap();

            state.unpolled_partitions.remove(&partition);

            if state.unpolled_partitions.is_empty() {
                // Last unpolled partition just arrived: release everyone.
                state.wakers.iter().for_each(|w| w.wake_by_ref());
                state.wakers.clear();
                Poll::Ready(())
            } else {
                state.wakers.push(cx.waker().clone());
                Poll::Pending
            }
        }
    }
1368
    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        // Pre-computed plan properties (see `compute_properties`).
        cache: Arc<PlanProperties>,
        // Shared poll-tracking state handed to every partition's stream.
        congestion: Arc<Congestion>,
    }
1377
    impl CongestedExec {
        /// Builds plan properties: an ordering over all columns (in schema
        /// order), hash partitioning into 3 partitions, incremental emission,
        /// and unbounded output.
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            // One column expression per schema field, in order.
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_ordering(
                columns
                    .iter()
                    .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
            );
            PlanProperties::new(
                eq_properties,
                Partitioning::Hash(columns, 3),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            )
        }
    }
1402
    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &Arc<PlanProperties> {
            &self.cache
        }
        // Leaf node: no children.
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        /// Returns a fresh [`CongestedStream`] for `partition` that shares
        /// the common [`Congestion`] poll-tracking state.
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion: Arc::clone(&self.congestion),
                partition,
            }))
        }
    }
1435
1436    impl DisplayAs for CongestedExec {
1437        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1438            match t {
1439                DisplayFormatType::Default | DisplayFormatType::Verbose => {
1440                    write!(f, "CongestedExec",).unwrap()
1441                }
1442                DisplayFormatType::TreeRender => {
1443                    // TODO: collect info
1444                    write!(f, "").unwrap()
1445                }
1446            }
1447            Ok(())
1448        }
1449    }
1450
    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        // Whether the exhausted partition (partition 0) has already been
        // polled; a second poll is a bug and panics.
        none_polled_once: bool,
        // Shared poll-tracking state for all partitions.
        congestion: Arc<Congestion>,
        // Index of the partition this stream represents.
        partition: usize,
    }
1460
    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                // Partition 0 registers itself as polled and immediately ends;
                // polling it again after end-of-stream is a bug and panics.
                0 => {
                    let _ = self.congestion.check_congested(self.partition, cx);
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                // Other partitions stay `Pending` until every partition has
                // been polled at least once, then end immediately.
                _ => {
                    ready!(self.congestion.check_congested(self.partition, cx));
                    Poll::Ready(None)
                }
            }
        }
    }
1484
    impl RecordBatchStream for CongestedStream {
        /// Returns the schema shared by all partitions of the stream.
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }
1490
    /// Regression test: the merge must keep polling all input partitions.
    /// `CongestedStream` only completes once every partition has been polled,
    /// so a merge that stops polling after one `Pending` would hang forever.
    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
        // Extract the partition count from whichever partitioning was built.
        let &partition_count = match properties.output_partitioning() {
            Partitioning::RoundRobinBatch(partitions) => partitions,
            Partitioning::Hash(_, partitions) => partitions,
            Partitioning::UnknownPartitioning(partitions) => partitions,
        };
        let source = CongestedExec {
            schema: schema.clone(),
            cache: Arc::new(properties),
            congestion: Arc::new(Congestion::new(partition_count)),
        };
        let spm = SortPreservingMergeExec::new(
            [PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))]
            .into(),
            Arc::new(source),
        );
        // Run on a separate task so a hang can be bounded by the timeout.
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        // Treat a 3-second stall as a deadlock.
        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
            Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
        }
    }
1523}