1use std::any::Any;
21use std::sync::Arc;
22
23use crate::common::spawn_buffered;
24use crate::limit::LimitStream;
25use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
26use crate::projection::{ProjectionExec, make_with_child, update_ordering};
27use crate::sorts::streaming_merge::StreamingMergeBuilder;
28use crate::{
29 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
30 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
31};
32
33use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
34use datafusion_execution::TaskContext;
35use datafusion_execution::memory_pool::MemoryConsumer;
36use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
37
38use crate::execution_plan::{EvaluationType, SchedulingType};
39use log::{debug, trace};
40
/// Merges the already-sorted partitions of its input into a single sorted
/// output partition, preserving the lexicographic ordering `expr`.
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan; every partition is expected to be sorted by `expr`.
    input: Arc<dyn ExecutionPlan>,
    /// Sort ordering that each input partition satisfies and the output preserves.
    expr: LexOrdering,
    /// Execution metrics (output rows, elapsed compute, timestamps, ...).
    metrics: ExecutionPlanMetricsSet,
    /// Optional upper bound on the number of rows to produce.
    fetch: Option<usize>,
    /// Cached plan properties (equivalences, partitioning, boundedness, ...).
    cache: PlanProperties,
    /// When true, the streaming merge uses round-robin tie breaking between
    /// inputs whose current rows compare equal (see `execute`).
    enable_round_robin_repartition: bool,
}
102
103impl SortPreservingMergeExec {
104 pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
106 let cache = Self::compute_properties(&input, expr.clone());
107 Self {
108 input,
109 expr,
110 metrics: ExecutionPlanMetricsSet::new(),
111 fetch: None,
112 cache,
113 enable_round_robin_repartition: true,
114 }
115 }
116
117 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
119 self.fetch = fetch;
120 self
121 }
122
123 pub fn with_round_robin_repartition(
133 mut self,
134 enable_round_robin_repartition: bool,
135 ) -> Self {
136 self.enable_round_robin_repartition = enable_round_robin_repartition;
137 self
138 }
139
140 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
142 &self.input
143 }
144
145 pub fn expr(&self) -> &LexOrdering {
147 &self.expr
148 }
149
150 pub fn fetch(&self) -> Option<usize> {
152 self.fetch
153 }
154
155 fn compute_properties(
158 input: &Arc<dyn ExecutionPlan>,
159 ordering: LexOrdering,
160 ) -> PlanProperties {
161 let input_partitions = input.output_partitioning().partition_count();
162 let (drive, scheduling) = if input_partitions > 1 {
163 (EvaluationType::Eager, SchedulingType::Cooperative)
164 } else {
165 (
166 input.properties().evaluation_type,
167 input.properties().scheduling_type,
168 )
169 };
170
171 let mut eq_properties = input.equivalence_properties().clone();
172 eq_properties.clear_per_partition_constants();
173 eq_properties.add_ordering(ordering);
174 PlanProperties::new(
175 eq_properties, Partitioning::UnknownPartitioning(1), input.pipeline_behavior(), input.boundedness(), )
180 .with_evaluation_type(drive)
181 .with_scheduling_type(scheduling)
182 }
183}
184
185impl DisplayAs for SortPreservingMergeExec {
186 fn fmt_as(
187 &self,
188 t: DisplayFormatType,
189 f: &mut std::fmt::Formatter,
190 ) -> std::fmt::Result {
191 match t {
192 DisplayFormatType::Default | DisplayFormatType::Verbose => {
193 write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
194 if let Some(fetch) = self.fetch {
195 write!(f, ", fetch={fetch}")?;
196 };
197
198 Ok(())
199 }
200 DisplayFormatType::TreeRender => {
201 if let Some(fetch) = self.fetch {
202 writeln!(f, "limit={fetch}")?;
203 };
204
205 for (i, e) in self.expr().iter().enumerate() {
206 e.fmt_sql(f)?;
207 if i != self.expr().len() - 1 {
208 write!(f, ", ")?;
209 }
210 }
211
212 Ok(())
213 }
214 }
215 }
216}
217
218impl ExecutionPlan for SortPreservingMergeExec {
219 fn name(&self) -> &'static str {
220 "SortPreservingMergeExec"
221 }
222
223 fn as_any(&self) -> &dyn Any {
225 self
226 }
227
228 fn properties(&self) -> &PlanProperties {
229 &self.cache
230 }
231
232 fn fetch(&self) -> Option<usize> {
233 self.fetch
234 }
235
236 fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
238 Some(Arc::new(Self {
239 input: Arc::clone(&self.input),
240 expr: self.expr.clone(),
241 metrics: self.metrics.clone(),
242 fetch: limit,
243 cache: self.cache.clone(),
244 enable_round_robin_repartition: true,
245 }))
246 }
247
248 fn required_input_distribution(&self) -> Vec<Distribution> {
249 vec![Distribution::UnspecifiedDistribution]
250 }
251
252 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
253 vec![false]
254 }
255
256 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
257 vec![Some(OrderingRequirements::from(self.expr.clone()))]
258 }
259
260 fn maintains_input_order(&self) -> Vec<bool> {
261 vec![true]
262 }
263
264 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
265 vec![&self.input]
266 }
267
268 fn with_new_children(
269 self: Arc<Self>,
270 children: Vec<Arc<dyn ExecutionPlan>>,
271 ) -> Result<Arc<dyn ExecutionPlan>> {
272 Ok(Arc::new(
273 SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
274 .with_fetch(self.fetch),
275 ))
276 }
277
278 fn execute(
279 &self,
280 partition: usize,
281 context: Arc<TaskContext>,
282 ) -> Result<SendableRecordBatchStream> {
283 trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
284 assert_eq_or_internal_err!(
285 partition,
286 0,
287 "SortPreservingMergeExec invalid partition {partition}"
288 );
289
290 let input_partitions = self.input.output_partitioning().partition_count();
291 trace!(
292 "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}"
293 );
294 let schema = self.schema();
295
296 let reservation =
297 MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
298 .register(&context.runtime_env().memory_pool);
299
300 match input_partitions {
301 0 => internal_err!(
302 "SortPreservingMergeExec requires at least one input partition"
303 ),
304 1 => match self.fetch {
305 Some(fetch) => {
306 let stream = self.input.execute(0, context)?;
307 debug!(
308 "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
309 );
310 Ok(Box::pin(LimitStream::new(
311 stream,
312 0,
313 Some(fetch),
314 BaselineMetrics::new(&self.metrics, partition),
315 )))
316 }
317 None => {
318 let stream = self.input.execute(0, context);
319 debug!(
320 "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
321 );
322 stream
323 }
324 },
325 _ => {
326 let receivers = (0..input_partitions)
327 .map(|partition| {
328 let stream =
329 self.input.execute(partition, Arc::clone(&context))?;
330 Ok(spawn_buffered(stream, 1))
331 })
332 .collect::<Result<_>>()?;
333
334 debug!(
335 "Done setting up sender-receiver for SortPreservingMergeExec::execute"
336 );
337
338 let result = StreamingMergeBuilder::new()
339 .with_streams(receivers)
340 .with_schema(schema)
341 .with_expressions(&self.expr)
342 .with_metrics(BaselineMetrics::new(&self.metrics, partition))
343 .with_batch_size(context.session_config().batch_size())
344 .with_fetch(self.fetch)
345 .with_reservation(reservation)
346 .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
347 .build()?;
348
349 debug!(
350 "Got stream result from SortPreservingMergeStream::new_from_receivers"
351 );
352
353 Ok(result)
354 }
355 }
356 }
357
358 fn metrics(&self) -> Option<MetricsSet> {
359 Some(self.metrics.clone_inner())
360 }
361
362 fn statistics(&self) -> Result<Statistics> {
363 self.input.partition_statistics(None)
364 }
365
366 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
367 self.input.partition_statistics(None)
368 }
369
370 fn supports_limit_pushdown(&self) -> bool {
371 true
372 }
373
374 fn try_swapping_with_projection(
378 &self,
379 projection: &ProjectionExec,
380 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
381 if projection.expr().len() >= projection.input().schema().fields().len() {
383 return Ok(None);
384 }
385
386 let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
387 else {
388 return Ok(None);
389 };
390
391 Ok(Some(Arc::new(
392 SortPreservingMergeExec::new(
393 updated_exprs,
394 make_with_child(projection, self.input())?,
395 )
396 .with_fetch(self.fetch()),
397 )))
398 }
399}
400
401#[cfg(test)]
402mod tests {
403 use std::collections::HashSet;
404 use std::fmt::Formatter;
405 use std::pin::Pin;
406 use std::sync::Mutex;
407 use std::task::{Context, Poll, Waker, ready};
408 use std::time::Duration;
409
410 use super::*;
411 use crate::coalesce_batches::CoalesceBatchesExec;
412 use crate::coalesce_partitions::CoalescePartitionsExec;
413 use crate::execution_plan::{Boundedness, EmissionType};
414 use crate::expressions::col;
415 use crate::metrics::{MetricValue, Timestamp};
416 use crate::repartition::RepartitionExec;
417 use crate::sorts::sort::SortExec;
418 use crate::stream::RecordBatchReceiverStream;
419 use crate::test::TestMemoryExec;
420 use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
421 use crate::test::{self, assert_is_pending, make_partition};
422 use crate::{collect, common};
423
424 use arrow::array::{
425 ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
426 TimestampNanosecondArray,
427 };
428 use arrow::compute::SortOptions;
429 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
430 use datafusion_common::test_util::batches_to_string;
431 use datafusion_common::{assert_batches_eq, exec_err};
432 use datafusion_common_runtime::SpawnedTask;
433 use datafusion_execution::RecordBatchStream;
434 use datafusion_execution::config::SessionConfig;
435 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
436 use datafusion_physical_expr::EquivalenceProperties;
437 use datafusion_physical_expr::expressions::Column;
438 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
439 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
440
441 use futures::{FutureExt, Stream, StreamExt};
442 use insta::assert_snapshot;
443 use tokio::time::timeout;
444
445 fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
448 let runtime = RuntimeEnvBuilder::new()
449 .with_memory_limit(20_000_000, 1.0)
450 .build_arc()?;
451 let config = SessionConfig::new();
452 let task_ctx = TaskContext::default()
453 .with_runtime(runtime)
454 .with_session_config(config);
455 Ok(Arc::new(task_ctx))
456 }
    /// Build a plan `TestMemoryExec -> RepartitionExec -> CoalesceBatchesExec
    /// -> SortPreservingMergeExec` whose rows all carry identical sort keys,
    /// so the merge must constantly break ties between its two partitions.
    fn generate_spm_for_round_robin_tie_breaker(
        enable_round_robin_repartition: bool,
    ) -> Result<Arc<SortPreservingMergeExec>> {
        let target_batch_size = 12500;
        let row_size = 12500;
        // Every row is identical: all sort keys compare equal.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;

        // 1024 copies of the batch to make the merge long-running.
        let rbs = (0..1024).map(|_| rb.clone()).collect::<Vec<_>>();

        let schema = rb.schema();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema)?,
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema)?,
                options: Default::default(),
            },
        ]
        .into();

        let repartition_exec = RepartitionExec::try_new(
            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
            Partitioning::RoundRobinBatch(2),
        )?;
        let coalesce_batches_exec =
            CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
        let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
            .with_round_robin_repartition(enable_round_robin_repartition);
        Ok(Arc::new(spm))
    }
494
    /// With the round-robin tie breaker enabled, merging many batches of
    /// all-equal keys completes within the configured memory limit.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_success() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
        let _collected = collect(spm, task_ctx).await?;
        Ok(())
    }
507
    /// Without the tie breaker the same merge is expected to fail —
    /// presumably by exceeding the 20 MB memory limit, though the test only
    /// asserts that collection returns an error.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
        let _err = collect(spm, task_ctx).await.unwrap_err();
        Ok(())
    }
520
    /// Two partitions whose `b` keys interleave (a/c/e/g vs b/d/f/h): the
    /// merged output alternates rows from the two inputs.
    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                // Both inputs have b == "j"; ties are resolved by `c`.
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
568
    /// Two partitions whose `b` key ranges partially overlap (a..e vs c..g).
    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
616
    /// Two partitions with disjoint `b` key ranges (a..e, then f..j): the
    /// merge output is simply one partition followed by the other.
    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
664
    /// Three-way merge with duplicate `b` keys across partitions; ties on
    /// `b` are resolved by the secondary key `c`.
    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
729
    /// Merge the given pre-sorted `partitions` (each sorted by `b`, then `c`)
    /// with a `SortPreservingMergeExec` and compare the result against the
    /// expected pretty-printed table `exp`.
    async fn _test_merge(
        partitions: &[Vec<RecordBatch>],
        exp: &[&str],
        context: Arc<TaskContext>,
    ) {
        let schema = partitions[0][0].schema();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]
        .into();
        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, context).await.unwrap();
        assert_batches_eq!(exp, collected.as_slice());
    }
753
754 async fn sorted_merge(
755 input: Arc<dyn ExecutionPlan>,
756 sort: LexOrdering,
757 context: Arc<TaskContext>,
758 ) -> RecordBatch {
759 let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
760 let mut result = collect(merge, context).await.unwrap();
761 assert_eq!(result.len(), 1);
762 result.remove(0)
763 }
764
765 async fn partition_sort(
766 input: Arc<dyn ExecutionPlan>,
767 sort: LexOrdering,
768 context: Arc<TaskContext>,
769 ) -> RecordBatch {
770 let sort_exec =
771 Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
772 sorted_merge(sort_exec, sort, context).await
773 }
774
775 async fn basic_sort(
776 src: Arc<dyn ExecutionPlan>,
777 sort: LexOrdering,
778 context: Arc<TaskContext>,
779 ) -> RecordBatch {
780 let merge = Arc::new(CoalescePartitionsExec::new(src));
781 let sort_exec = Arc::new(SortExec::new(sort, merge));
782 let mut result = collect(sort_exec, context).await.unwrap();
783 assert_eq!(result.len(), 1);
784 result.remove(0)
785 }
786
    /// Per-partition sort + merge must produce exactly the same output as a
    /// global sort over the coalesced partitions.
    #[tokio::test]
    async fn test_partition_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);
        let schema = csv.schema();

        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();

        let basic =
            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;

        // Compare the pretty-printed renderings of both results.
        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }
821
822 fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
824 let batches = sorted.num_rows().div_ceil(batch_size);
825
826 (0..batches)
828 .map(|batch_idx| {
829 let columns = (0..sorted.num_columns())
830 .map(|column_idx| {
831 let length =
832 batch_size.min(sorted.num_rows() - batch_idx * batch_size);
833
834 sorted
835 .column(column_idx)
836 .slice(batch_idx * batch_size, length)
837 })
838 .collect();
839
840 RecordBatch::try_new(sorted.schema(), columns).unwrap()
841 })
842 .collect()
843 }
844
    /// Produce a sorted, multi-partition input: scan 4 partitions, globally
    /// sort them, then re-split the sorted data into batches of the given
    /// `sizes` — one output partition per entry in `sizes`.
    async fn sorted_partitioned_input(
        sort: LexOrdering,
        sizes: &[usize],
        context: Arc<TaskContext>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);

        let sorted = basic_sort(csv, sort, context).await;
        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();

        TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
    }
858
    /// Merging sorted partitions with uneven batch sizes must match a global
    /// sort of the same data.
    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition =
            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }
887
    /// The merge must respect the configured output batch size while still
    /// producing the same rows as a global sort.
    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // Use a small batch size (23) so the merged output is re-chunked.
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await?;

        // 1200 rows in batches of 23 -> ceil(1200 / 23) = 53 batches.
        assert_eq!(merged.len(), 53);
        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }
923
    /// Null handling: `b` sorts nulls first, `c` sorts nulls last.
    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+-------------------------------+
        | a | b | c                             |
        +---+---+-------------------------------+
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 2 | a |                               |
        | 7 | b | 1970-01-01T00:00:00.000000006 |
        | 2 | b |                               |
        | 9 | d |                               |
        | 3 | e | 1970-01-01T00:00:00.000000004 |
        | 3 | g | 1970-01-01T00:00:00.000000005 |
        | 4 | h |                               |
        | 5 | i | 1970-01-01T00:00:00.000000004 |
        +---+---+-------------------------------+
        ");
    }
1003
    /// Single input partition + fetch: the merge passes the input through a
    /// `LimitStream` (see `SortPreservingMergeExec::execute`), keeping 2 rows.
    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge =
            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        +---+---+
        ");
    }
1036
    /// Single input partition, no fetch: the merge is a pure pass-through.
    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        | 7 | c |
        | 9 | d |
        | 3 | e |
        +---+---+
        ");
    }
1071
    /// Drive the low-level `StreamingMergeBuilder` directly over slow,
    /// task-backed input streams and compare with a global sort.
    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into();

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        // Re-expose each partition as a receiver stream whose producer task
        // sleeps between batches, exercising the merge's async polling.
        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(&sort)
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }
1143
    /// The merge reports output rows, elapsed compute time, and start/end
    /// timestamps in its metrics set.
    #[tokio::test]
    async fn test_merge_metrics() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +----+---+
        | a  | b |
        +----+---+
        | 1  | a |
        | 10 | b |
        | 2  | c |
        | 20 | d |
        +----+---+
        ");

        let metrics = merge.metrics().unwrap();

        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        // Both a start and an end timestamp must be present and non-zero.
        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }
1202
1203 fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
1204 ts.value().unwrap().timestamp_nanos_opt().unwrap()
1205 }
1206
1207 #[tokio::test]
1208 async fn test_drop_cancel() -> Result<()> {
1209 let task_ctx = Arc::new(TaskContext::default());
1210 let schema =
1211 Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1212
1213 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1214 let refs = blocking_exec.refs();
1215 let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
1216 [PhysicalSortExpr {
1217 expr: col("a", &schema)?,
1218 options: SortOptions::default(),
1219 }]
1220 .into(),
1221 blocking_exec,
1222 ));
1223
1224 let fut = collect(sort_preserving_merge_exec, task_ctx);
1225 let mut fut = fut.boxed();
1226
1227 assert_is_pending(&mut fut);
1228 drop(fut);
1229 assert_strong_count_converges_to_zero(refs).await;
1230
1231 Ok(())
1232 }
1233
1234 #[tokio::test]
1235 async fn test_stable_sort() {
1236 let task_ctx = Arc::new(TaskContext::default());
1237
1238 let partitions: Vec<Vec<RecordBatch>> = (0..10)
1246 .map(|batch_number| {
1247 let batch_number: Int32Array =
1248 vec![Some(batch_number), Some(batch_number)]
1249 .into_iter()
1250 .collect();
1251 let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();
1252
1253 let batch = RecordBatch::try_from_iter(vec![
1254 ("batch_number", Arc::new(batch_number) as ArrayRef),
1255 ("value", Arc::new(value) as ArrayRef),
1256 ])
1257 .unwrap();
1258
1259 vec![batch]
1260 })
1261 .collect();
1262
1263 let schema = partitions[0][0].schema();
1264
1265 let sort = [PhysicalSortExpr {
1266 expr: col("value", &schema).unwrap(),
1267 options: SortOptions {
1268 descending: false,
1269 nulls_first: true,
1270 },
1271 }]
1272 .into();
1273
1274 let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
1275 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1276
1277 let collected = collect(merge, task_ctx).await.unwrap();
1278 assert_eq!(collected.len(), 1);
1279
1280 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1284 +--------------+-------+
1285 | batch_number | value |
1286 +--------------+-------+
1287 | 0 | A |
1288 | 1 | A |
1289 | 2 | A |
1290 | 3 | A |
1291 | 4 | A |
1292 | 5 | A |
1293 | 6 | A |
1294 | 7 | A |
1295 | 8 | A |
1296 | 9 | A |
1297 | 0 | B |
1298 | 1 | B |
1299 | 2 | B |
1300 | 3 | B |
1301 | 4 | B |
1302 | 5 | B |
1303 | 6 | B |
1304 | 7 | B |
1305 | 8 | B |
1306 | 9 | B |
1307 +--------------+-------+
1308 ");
1309 }
1310
    /// Shared bookkeeping used to simulate congestion across partitions:
    /// which partitions still have not been polled, and which tasks are
    /// parked waiting for that to happen.
    #[derive(Debug)]
    struct CongestionState {
        // Wakers of tasks parked until every partition has been polled once.
        wakers: Vec<Waker>,
        // Partitions that have not yet been polled.
        unpolled_partitions: HashSet<usize>,
    }
1316
    /// Synchronized wrapper around [`CongestionState`], shared by all
    /// partition streams of a [`CongestedExec`].
    #[derive(Debug)]
    struct Congestion {
        congestion_state: Mutex<CongestionState>,
    }
1321
1322 impl Congestion {
1323 fn new(partition_count: usize) -> Self {
1324 Congestion {
1325 congestion_state: Mutex::new(CongestionState {
1326 wakers: vec![],
1327 unpolled_partitions: (0usize..partition_count).collect(),
1328 }),
1329 }
1330 }
1331
1332 fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
1333 let mut state = self.congestion_state.lock().unwrap();
1334
1335 state.unpolled_partitions.remove(&partition);
1336
1337 if state.unpolled_partitions.is_empty() {
1338 state.wakers.iter().for_each(|w| w.wake_by_ref());
1339 state.wakers.clear();
1340 Poll::Ready(())
1341 } else {
1342 state.wakers.push(cx.waker().clone());
1343 Poll::Pending
1344 }
1345 }
1346 }
1347
    /// Test execution plan whose partition streams coordinate through a
    /// shared [`Congestion`] gate, used to detect merge-side deadlocks.
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        congestion: Arc<Congestion>,
    }
1356
1357 impl CongestedExec {
1358 fn compute_properties(schema: SchemaRef) -> PlanProperties {
1359 let columns = schema
1360 .fields
1361 .iter()
1362 .enumerate()
1363 .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
1364 .collect::<Vec<_>>();
1365 let mut eq_properties = EquivalenceProperties::new(schema);
1366 eq_properties.add_ordering(
1367 columns
1368 .iter()
1369 .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
1370 );
1371 PlanProperties::new(
1372 eq_properties,
1373 Partitioning::Hash(columns, 3),
1374 EmissionType::Incremental,
1375 Boundedness::Unbounded {
1376 requires_infinite_memory: false,
1377 },
1378 )
1379 }
1380 }
1381
1382 impl ExecutionPlan for CongestedExec {
1383 fn name(&self) -> &'static str {
1384 Self::static_name()
1385 }
1386 fn as_any(&self) -> &dyn Any {
1387 self
1388 }
1389 fn properties(&self) -> &PlanProperties {
1390 &self.cache
1391 }
1392 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1393 vec![]
1394 }
1395 fn with_new_children(
1396 self: Arc<Self>,
1397 _: Vec<Arc<dyn ExecutionPlan>>,
1398 ) -> Result<Arc<dyn ExecutionPlan>> {
1399 Ok(self)
1400 }
1401 fn execute(
1402 &self,
1403 partition: usize,
1404 _context: Arc<TaskContext>,
1405 ) -> Result<SendableRecordBatchStream> {
1406 Ok(Box::pin(CongestedStream {
1407 schema: Arc::new(self.schema.clone()),
1408 none_polled_once: false,
1409 congestion: Arc::clone(&self.congestion),
1410 partition,
1411 }))
1412 }
1413 }
1414
1415 impl DisplayAs for CongestedExec {
1416 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1417 match t {
1418 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1419 write!(f, "CongestedExec",).unwrap()
1420 }
1421 DisplayFormatType::TreeRender => {
1422 write!(f, "").unwrap()
1424 }
1425 }
1426 Ok(())
1427 }
1428 }
1429
    /// Partition stream for [`CongestedExec`]: yields no batches and, except
    /// for partition 0, completes only once every partition has been polled.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        // Guards against the merge polling an exhausted stream again.
        none_polled_once: bool,
        // Shared gate tracking which partitions have been polled.
        congestion: Arc<Congestion>,
        partition: usize,
    }
1439
    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    // Partition 0 finishes immediately (it does not wait for
                    // the other partitions) but still registers itself with
                    // the congestion tracker; the poll result is deliberately
                    // discarded.
                    let _ = self.congestion.check_congested(self.partition, cx);
                    if self.none_polled_once {
                        // A stream that already returned `None` must never be
                        // polled again; doing so indicates a merge-side bug.
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                _ => {
                    // Remaining partitions stay `Pending` until every
                    // partition has been polled at least once, then end with
                    // no data.
                    ready!(self.congestion.check_congested(self.partition, cx));
                    Poll::Ready(None)
                }
            }
        }
    }
1463
    impl RecordBatchStream for CongestedStream {
        /// Returns the schema shared by all (zero) batches of this stream.
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }
1469
1470 #[tokio::test]
1471 async fn test_spm_congestion() -> Result<()> {
1472 let task_ctx = Arc::new(TaskContext::default());
1473 let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
1474 let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
1475 let &partition_count = match properties.output_partitioning() {
1476 Partitioning::RoundRobinBatch(partitions) => partitions,
1477 Partitioning::Hash(_, partitions) => partitions,
1478 Partitioning::UnknownPartitioning(partitions) => partitions,
1479 };
1480 let source = CongestedExec {
1481 schema: schema.clone(),
1482 cache: properties,
1483 congestion: Arc::new(Congestion::new(partition_count)),
1484 };
1485 let spm = SortPreservingMergeExec::new(
1486 [PhysicalSortExpr::new_default(Arc::new(Column::new(
1487 "c1", 0,
1488 )))]
1489 .into(),
1490 Arc::new(source),
1491 );
1492 let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));
1493
1494 let result = timeout(Duration::from_secs(3), spm_task.join()).await;
1495 match result {
1496 Ok(Ok(Ok(_batches))) => Ok(()),
1497 Ok(Ok(Err(e))) => Err(e),
1498 Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
1499 Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
1500 }
1501 }
1502}