datafusion_physical_plan/sorts/sort_preserving_merge.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.

use std::any::Any;
use std::sync::Arc;

use crate::common::spawn_buffered;
use crate::limit::LimitStream;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{make_with_child, update_ordering, ProjectionExec};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
};

use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};

use crate::execution_plan::{EvaluationType, SchedulingType};
use log::{debug, trace};

/// Sort preserving merge execution plan
///
/// # Overview
///
/// This operator implements a K-way merge. It is used to merge multiple sorted
/// streams into a single sorted stream and is highly optimized.
///
/// ## Inputs:
///
/// 1. A list of sort expressions
/// 2. An input plan, where each partition is sorted with respect to
///    these sort expressions.
///
/// ## Output:
///
/// 1. A single partition that is also sorted with respect to the expressions
///
/// ## Diagram
///
/// ```text
/// ┌─────────────────────────┐
/// │ ┌───┬───┬───┬───┐       │
/// │ │ A │ B │ C │ D │ ...   │──┐
/// │ └───┴───┴───┴───┘       │  │
/// └─────────────────────────┘  │  ┌───────────────────┐    ┌───────────────────────────────┐
///   Stream 1                   │  │                   │    │ ┌───┬───╦═══╦───┬───╦═══╗     │
///                              ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │
///                              │  │                   │    │ └───┴─▲─╩═══╩───┴───╩═══╝     │
/// ┌─────────────────────────┐  │  └───────────────────┘    └─┬─────┴───────────────────────┘
/// │ ╔═══╦═══╗               │  │
/// │ ║ B ║ E ║     ...       │──┘                             │
/// │ ╚═══╩═══╝               │              Stable sort if `enable_round_robin_repartition=false`:
/// └─────────────────────────┘              the merged stream places equal rows from stream 1 first
///   Stream 2
///
///
///  Input Partitions                                          Output Partition
///    (sorted)                                                  (sorted)
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// the output and inputs are not polled again.
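///
/// # Example
///
/// A minimal construction sketch (illustrative only, not compiled as a doc
/// test; assumes `input` is an `Arc<dyn ExecutionPlan>` whose partitions are
/// each already sorted by a column named "a"):
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
/// use datafusion_physical_plan::expressions::col;
/// use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
///
/// let ordering: LexOrdering = [PhysicalSortExpr {
///     expr: col("a", &input.schema())?,
///     options: Default::default(),
/// }]
/// .into();
///
/// // Merge all sorted input partitions into one sorted output partition,
/// // stopping after the first 100 rows.
/// let merge = SortPreservingMergeExec::new(ordering, input).with_fetch(Some(100));
/// ```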
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan with sorted partitions
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch. Stops producing rows once this many have been emitted
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning, etc.
    cache: PlanProperties,
    /// Use round-robin selection of tied winners of the loser tree
    ///
    /// See [`Self::with_round_robin_repartition`] for more information.
    enable_round_robin_repartition: bool,
}

impl SortPreservingMergeExec {
    /// Create a new sort-preserving merge execution plan
    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
        let cache = Self::compute_properties(&input, expr.clone());
        Self {
            input,
            expr,
            metrics: ExecutionPlanMetricsSet::new(),
            fetch: None,
            cache,
            enable_round_robin_repartition: true,
        }
    }

    /// Sets the number of rows to fetch
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// Sets the selection strategy of tied winners of the loser tree algorithm
    ///
    /// If true (the default), equal output rows are placed in the merged stream
    /// in round-robin fashion. This approach consumes input streams at more
    /// even rates when there are many rows with the same sort key.
    ///
    /// If false, equal output rows are always placed in the merged stream in
    /// the order of the inputs, resulting in potentially slower execution but a
    /// stable output order.
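    ///
    /// A sketch of opting into the stable behavior (assumes `ordering` and
    /// `input` as in [`Self::new`]):
    ///
    /// ```rust,ignore
    /// // Ties between inputs are broken by input order, giving a stable merge.
    /// let merge = SortPreservingMergeExec::new(ordering, input)
    ///     .with_round_robin_repartition(false);
    /// ```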
    pub fn with_round_robin_repartition(
        mut self,
        enable_round_robin_repartition: bool,
    ) -> Self {
        self.enable_round_robin_repartition = enable_round_robin_repartition;
        self
    }

    /// Input plan with sorted partitions
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Sort expressions
    pub fn expr(&self) -> &LexOrdering {
        &self.expr
    }

    /// Optional number of rows to fetch
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Creates the cache object that stores the plan properties
    /// such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        ordering: LexOrdering,
    ) -> PlanProperties {
        let input_partitions = input.output_partitioning().partition_count();
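        // With more than one input partition the merge drives its children
        // eagerly and cooperates with the scheduler; with a single input it
        // simply inherits the child's behavior.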
        let (drive, scheduling) = if input_partitions > 1 {
            (EvaluationType::Eager, SchedulingType::Cooperative)
        } else {
            (
                input.properties().evaluation_type,
                input.properties().scheduling_type,
            )
        };

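        // Constants that held within each input partition need not hold for
        // the merged output, but the requested ordering does hold on the
        // single output partition.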
        let mut eq_properties = input.equivalence_properties().clone();
        eq_properties.clear_per_partition_constants();
        eq_properties.add_ordering(ordering);
        PlanProperties::new(
            eq_properties,                        // Equivalence Properties
            Partitioning::UnknownPartitioning(1), // Output Partitioning
            input.pipeline_behavior(),            // Pipeline Behavior
            input.boundedness(),                  // Boundedness
        )
        .with_evaluation_type(drive)
        .with_scheduling_type(scheduling)
    }
}

impl DisplayAs for SortPreservingMergeExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
                if let Some(fetch) = self.fetch {
                    write!(f, ", fetch={fetch}")?;
                };

                Ok(())
            }
            DisplayFormatType::TreeRender => {
                if let Some(fetch) = self.fetch {
                    writeln!(f, "limit={fetch}")?;
                };

                for (i, e) in self.expr().iter().enumerate() {
                    e.fmt_sql(f)?;
                    if i != self.expr().len() - 1 {
                        write!(f, ", ")?;
                    }
                }

                Ok(())
            }
        }
    }
}

impl ExecutionPlan for SortPreservingMergeExec {
    fn name(&self) -> &'static str {
        "SortPreservingMergeExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Sets the number of rows to fetch
    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            expr: self.expr.clone(),
            metrics: self.metrics.clone(),
            fetch: limit,
            cache: self.cache.clone(),
            // Preserve the configured tie-breaker strategy rather than
            // silently resetting it when the fetch changes.
            enable_round_robin_repartition: self.enable_round_robin_repartition,
        }))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::UnspecifiedDistribution]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![Some(OrderingRequirements::from(self.expr.clone()))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(
            SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
                .with_fetch(self.fetch),
        ))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
        if partition != 0 {
            return internal_err!(
                "SortPreservingMergeExec invalid partition {partition}"
            );
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        trace!(
            "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}"
        );
        let schema = self.schema();

        let reservation =
            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
                .register(&context.runtime_env().memory_pool);

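        // Three cases follow: zero input partitions is an internal error, a
        // single input needs no merge (at most a LimitStream wrapper), and
        // multiple inputs go through the streaming K-way merge.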
        match input_partitions {
            0 => internal_err!(
                "SortPreservingMergeExec requires at least one input partition"
            ),
            1 => match self.fetch {
                Some(fetch) => {
                    let stream = self.input.execute(0, context)?;
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}");
                    Ok(Box::pin(LimitStream::new(
                        stream,
                        0,
                        Some(fetch),
                        BaselineMetrics::new(&self.metrics, partition),
                    )))
                }
                None => {
                    let stream = self.input.execute(0, context);
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch");
                    stream
                }
            },
            _ => {
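                // Spawn one buffered task per input partition so every input
                // can make progress while the merge polls them.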
                let receivers = (0..input_partitions)
                    .map(|partition| {
                        let stream =
                            self.input.execute(partition, Arc::clone(&context))?;
                        Ok(spawn_buffered(stream, 1))
                    })
                    .collect::<Result<_>>()?;

                debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

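                // The loser-tree streaming merge performs the actual K-way
                // merge; the optional round-robin tie breaker balances how
                // quickly inputs with equal sort keys are consumed.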
                let result = StreamingMergeBuilder::new()
                    .with_streams(receivers)
                    .with_schema(schema)
                    .with_expressions(&self.expr)
                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(self.fetch)
                    .with_reservation(reservation)
                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
                    .build()?;

                debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

                Ok(result)
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    /// Tries to swap the projection with its input [`SortPreservingMergeExec`].
    /// If this is possible, it returns the new [`SortPreservingMergeExec`] whose
    /// child is a projection. Otherwise, it returns `None`.
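    ///
    /// A sketch of the rewrite (plans shown textually; assumes the projection
    /// keeps the sort column `a` while narrowing the schema):
    ///
    /// ```text
    /// ProjectionExec: expr=[a]                SortPreservingMergeExec: [a ASC]
    ///   SortPreservingMergeExec: [a ASC]  =>    ProjectionExec: expr=[a]
    ///     ...                                     ...
    /// ```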
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, we should not try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
        else {
            return Ok(None);
        };

        Ok(Some(Arc::new(
            SortPreservingMergeExec::new(
                updated_exprs,
                make_with_child(projection, self.input())?,
            )
            .with_fetch(self.fetch()),
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::fmt::Formatter;
    use std::pin::Pin;
    use std::sync::Mutex;
    use std::task::{ready, Context, Poll, Waker};
    use std::time::Duration;

    use super::*;
    use crate::coalesce_batches::CoalesceBatchesExec;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::execution_plan::{Boundedness, EmissionType};
    use crate::expressions::col;
    use crate::metrics::{MetricValue, Timestamp};
    use crate::repartition::RepartitionExec;
    use crate::sorts::sort::SortExec;
    use crate::stream::RecordBatchReceiverStream;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::TestMemoryExec;
    use crate::test::{self, assert_is_pending, make_partition};
    use crate::{collect, common};

    use arrow::array::{
        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
        TimestampNanosecondArray,
    };
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::test_util::batches_to_string;
    use datafusion_common::{assert_batches_eq, exec_err};
    use datafusion_common_runtime::SpawnedTask;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_execution::RecordBatchStream;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::EquivalenceProperties;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

    use futures::{FutureExt, Stream, StreamExt};
    use insta::assert_snapshot;
    use tokio::time::timeout;

    // The constants in this function are tightly coupled to the memory limit
    // under test; change them together with the limit.
    fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(20_000_000, 1.0)
            .build_arc()?;
        let config = SessionConfig::new();
        let task_ctx = TaskContext::default()
            .with_runtime(runtime)
            .with_session_config(config);
        Ok(Arc::new(task_ctx))
    }

    // The constants in this function are tightly coupled to the memory limit
    // under test; change them together with the limit.
    fn generate_spm_for_round_robin_tie_breaker(
        enable_round_robin_repartition: bool,
    ) -> Result<Arc<SortPreservingMergeExec>> {
        let target_batch_size = 12500;
        let row_size = 12500;
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;

        let rbs = (0..1024).map(|_| rb.clone()).collect::<Vec<_>>();

        let schema = rb.schema();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema)?,
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema)?,
                options: Default::default(),
            },
        ]
        .into();

        let repartition_exec = RepartitionExec::try_new(
            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
            Partitioning::RoundRobinBatch(2),
        )?;
        let coalesce_batches_exec =
            CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
        let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
            .with_round_robin_repartition(enable_round_robin_repartition);
        Ok(Arc::new(spm))
    }

    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
    /// Any errors here could indicate unintended changes in tie breaker logic.
    ///
    /// Note: If you adjust constants in this test, ensure that memory usage differs
    /// based on whether the tie breaker is enabled or disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_success() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
        let _collected = collect(spm, task_ctx).await?;
        Ok(())
    }

    /// This test verifies that the memory limit is exceeded (and collection fails)
    /// when the tie breaker is disabled. Any errors here could indicate unintended
    /// changes in tie breaker logic.
    ///
    /// Note: If you adjust constants in this test, ensure that memory usage differs
    /// based on whether the tie breaker is enabled or disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
        let _err = collect(spm, task_ctx).await.unwrap_err();
        Ok(())
    }

    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    async fn _test_merge(
        partitions: &[Vec<RecordBatch>],
        exp: &[&str],
        context: Arc<TaskContext>,
    ) {
        let schema = partitions[0][0].schema();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]
        .into();
        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, context).await.unwrap();
        assert_batches_eq!(exp, collected.as_slice());
    }

    async fn sorted_merge(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let mut result = collect(merge, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    async fn partition_sort(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let sort_exec =
            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
        sorted_merge(sort_exec, sort, context).await
    }

    async fn basic_sort(
        src: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(CoalescePartitionsExec::new(src));
        let sort_exec = Arc::new(SortExec::new(sort, merge));
        let mut result = collect(sort_exec, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    #[tokio::test]
    async fn test_partition_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);
        let schema = csv.schema();

        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();

        let basic =
            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    // Split the provided record batch into multiple record batches of at most
    // `batch_size` rows each
    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
        let batches = sorted.num_rows().div_ceil(batch_size);

        // Split the sorted RecordBatch into multiple batches
        (0..batches)
            .map(|batch_idx| {
                let columns = (0..sorted.num_columns())
                    .map(|column_idx| {
                        let length =
                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);

                        sorted
                            .column(column_idx)
                            .slice(batch_idx * batch_size, length)
                    })
                    .collect();

                RecordBatch::try_new(sorted.schema(), columns).unwrap()
            })
            .collect()
    }

    async fn sorted_partitioned_input(
        sort: LexOrdering,
        sizes: &[usize],
        context: Arc<TaskContext>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);

        let sorted = basic_sort(csv, sort, context).await;
        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();

        TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition =
            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // batch size of 23
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await?;

        assert_eq!(merged.len(), 53);
        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
            +---+---+-------------------------------+
            | a | b | c                             |
            +---+---+-------------------------------+
            | 1 |   | 1970-01-01T00:00:00.000000008 |
            | 1 |   | 1970-01-01T00:00:00.000000008 |
            | 2 | a |                               |
            | 7 | b | 1970-01-01T00:00:00.000000006 |
            | 2 | b |                               |
            | 9 | d |                               |
            | 3 | e | 1970-01-01T00:00:00.000000004 |
            | 3 | g | 1970-01-01T00:00:00.000000005 |
            | 4 | h |                               |
            | 5 | i | 1970-01-01T00:00:00.000000004 |
            +---+---+-------------------------------+
            "#);
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge =
            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
            +---+---+
            | a | b |
            +---+---+
            | 1 | a |
            | 2 | b |
            +---+---+
            "#);
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
            +---+---+
            | a | b |
            +---+---+
            | 1 | a |
            | 2 | b |
            | 7 | c |
            | 9 | d |
            | 3 | e |
            +---+---+
            "#);
    }

    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into();

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    // This causes the MergeStream to wait for more input
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(&sort)
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_merge_metrics() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
            +----+---+
            | a  | b |
            +----+---+
            | 1  | a |
            | 10 | b |
            | 2  | c |
            | 20 | d |
            +----+---+
            "#);

        // Now, validate metrics
        let metrics = merge.metrics().unwrap();

        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }

    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
        ts.value().unwrap().timestamp_nanos_opt().unwrap()
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
            [PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            blocking_exec,
        ));

        let fut = collect(sort_preserving_merge_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_stable_sort() {
        let task_ctx = Arc::new(TaskContext::default());

        // Create record batches like:
        // batch_number |value
        // -------------+------
        //    1         | A
        //    1         | B
        //
        // Ensure that the output is in the same order the batches were fed
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        let sort = [PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();

        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Expect the data to be sorted first by "batch_number" (because
        // that was the order it was fed in, even though only "value"
        // is in the sort key)
        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
                +--------------+-------+
                | batch_number | value |
                +--------------+-------+
                | 0            | A     |
                | 1            | A     |
                | 2            | A     |
                | 3            | A     |
                | 4            | A     |
                | 5            | A     |
                | 6            | A     |
                | 7            | A     |
                | 8            | A     |
                | 9            | A     |
                | 0            | B     |
                | 1            | B     |
                | 2            | B     |
                | 3            | B     |
                | 4            | B     |
                | 5            | B     |
                | 6            | B     |
                | 7            | B     |
                | 8            | B     |
                | 9            | B     |
                +--------------+-------+
            "#);
    }

    #[derive(Debug)]
    struct CongestionState {
        wakers: Vec<Waker>,
        unpolled_partitions: HashSet<usize>,
    }

    #[derive(Debug)]
    struct Congestion {
        congestion_state: Mutex<CongestionState>,
    }

    impl Congestion {
        fn new(partition_count: usize) -> Self {
            Congestion {
                congestion_state: Mutex::new(CongestionState {
                    wakers: vec![],
                    unpolled_partitions: (0usize..partition_count).collect(),
                }),
            }
        }

        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
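            // Record that this partition has been polled; readiness is reached
            // only once every partition has been polled at least once, at
            // which point all parked wakers are woken.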
            let mut state = self.congestion_state.lock().unwrap();

            state.unpolled_partitions.remove(&partition);

            if state.unpolled_partitions.is_empty() {
                state.wakers.iter().for_each(|w| w.wake_by_ref());
                state.wakers.clear();
                Poll::Ready(())
            } else {
                state.wakers.push(cx.waker().clone());
                Poll::Pending
            }
        }
    }

    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        congestion: Arc<Congestion>,
    }

    impl CongestedExec {
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_ordering(
                columns
                    .iter()
                    .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
            );
            PlanProperties::new(
                eq_properties,
                Partitioning::Hash(columns, 3),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            )
        }
    }

    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &PlanProperties {
            &self.cache
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion: Arc::clone(&self.congestion),
                partition,
            }))
        }
    }

    impl DisplayAs for CongestedExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CongestedExec",).unwrap()
                }
                DisplayFormatType::TreeRender => {
                    // TODO: collect info
                    write!(f, "").unwrap()
                }
            }
            Ok(())
        }
    }

    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        none_polled_once: bool,
        congestion: Arc<Congestion>,
        partition: usize,
    }

    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    let _ = self.congestion.check_congested(self.partition, cx);
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                _ => {
                    ready!(self.congestion.check_congested(self.partition, cx));
                    Poll::Ready(None)
                }
            }
        }
    }

    impl RecordBatchStream for CongestedStream {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }

    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
        let &partition_count = match properties.output_partitioning() {
            Partitioning::RoundRobinBatch(partitions) => partitions,
            Partitioning::Hash(_, partitions) => partitions,
            Partitioning::UnknownPartitioning(partitions) => partitions,
        };
        let source = CongestedExec {
            schema: schema.clone(),
            cache: properties,
            congestion: Arc::new(Congestion::new(partition_count)),
        };
        let spm = SortPreservingMergeExec::new(
            [PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))]
            .into(),
            Arc::new(source),
        );
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
            Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
        }
    }
}