use std::any::Any;
use std::sync::Arc;

use crate::common::spawn_buffered;
use crate::limit::LimitStream;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{make_with_child, update_ordering, ProjectionExec};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
};

use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};

use crate::execution_plan::{EvaluationType, SchedulingType};
use log::{debug, trace};
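
/// Sort preserving merge execution plan.
///
/// Takes an input plan whose partitions are each sorted by `expr` and merges
/// them into a single sorted output partition without re-sorting. An optional
/// `fetch` stops execution once the requested number of rows has been emitted.
///
/// A minimal usage sketch (assuming `input: Arc<dyn ExecutionPlan>` produces
/// partitions that are already sorted by a column `a`, and `schema` is its
/// schema):
///
/// ```ignore
/// use crate::expressions::col;
/// use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
///
/// let ordering: LexOrdering = [PhysicalSortExpr {
///     expr: col("a", &schema)?,
///     options: Default::default(),
/// }]
/// .into();
/// // Merge all sorted partitions of `input`, keeping at most 100 rows.
/// let merge = SortPreservingMergeExec::new(ordering, input).with_fetch(Some(100));
/// ```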
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan with sorted partitions
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch
    fetch: Option<usize>,
    /// Cached plan properties (schema, equivalences, ordering, partitioning)
    cache: PlanProperties,
    /// Use the round-robin tie breaker when polling streams with equal keys
    enable_round_robin_repartition: bool,
}

impl SortPreservingMergeExec {
    /// Create a new sort execution plan
    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
        let cache = Self::compute_properties(&input, expr.clone());
        Self {
            input,
            expr,
            metrics: ExecutionPlanMetricsSet::new(),
            fetch: None,
            cache,
            enable_round_robin_repartition: true,
        }
    }

    /// Sets the number of rows to fetch
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// Enables or disables the round-robin tie breaker used when several
    /// input streams are tied on the sort key
    pub fn with_round_robin_repartition(
        mut self,
        enable_round_robin_repartition: bool,
    ) -> Self {
        self.enable_round_robin_repartition = enable_round_robin_repartition;
        self
    }

    /// Input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Sort expressions
    pub fn expr(&self) -> &LexOrdering {
        &self.expr
    }

    /// Optional number of rows to fetch
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Computes the cached plan properties: the output is a single partition
    /// sorted by `ordering`, and per-partition constants no longer hold once
    /// partitions are merged.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        ordering: LexOrdering,
    ) -> PlanProperties {
        let input_partitions = input.output_partitioning().partition_count();
        let (drive, scheduling) = if input_partitions > 1 {
            (EvaluationType::Eager, SchedulingType::Cooperative)
        } else {
            (
                input.properties().evaluation_type,
                input.properties().scheduling_type,
            )
        };

        let mut eq_properties = input.equivalence_properties().clone();
        eq_properties.clear_per_partition_constants();
        eq_properties.add_ordering(ordering);
        PlanProperties::new(
            eq_properties,
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_evaluation_type(drive)
        .with_scheduling_type(scheduling)
    }
}

impl DisplayAs for SortPreservingMergeExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
                if let Some(fetch) = self.fetch {
                    write!(f, ", fetch={fetch}")?;
                };

                Ok(())
            }
            DisplayFormatType::TreeRender => {
                if let Some(fetch) = self.fetch {
                    writeln!(f, "limit={fetch}")?;
                };

                for (i, e) in self.expr().iter().enumerate() {
                    e.fmt_sql(f)?;
                    if i != self.expr().len() - 1 {
                        write!(f, ", ")?;
                    }
                }

                Ok(())
            }
        }
    }
}

impl ExecutionPlan for SortPreservingMergeExec {
    fn name(&self) -> &'static str {
        "SortPreservingMergeExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            expr: self.expr.clone(),
            metrics: self.metrics.clone(),
            fetch: limit,
            cache: self.cache.clone(),
            enable_round_robin_repartition: true,
        }))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::UnspecifiedDistribution]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        vec![Some(OrderingRequirements::from(self.expr.clone()))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(
            SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
                .with_fetch(self.fetch),
        ))
    }
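
    // Merge strategy, by number of input partitions:
    //   0 => plan error;
    //   1 => the input is already sorted, so it is passed through, with
    //        `fetch` applied via a `LimitStream` if requested;
    //   N => each partition is buffered in its own task and the streams
    //        are combined with `StreamingMergeBuilder`.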
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
        if 0 != partition {
            return internal_err!(
                "SortPreservingMergeExec invalid partition {partition}"
            );
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        trace!(
            "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}"
        );
        let schema = self.schema();

        let reservation =
            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
                .register(&context.runtime_env().memory_pool);

        match input_partitions {
            0 => internal_err!(
                "SortPreservingMergeExec requires at least one input partition"
            ),
            1 => match self.fetch {
                // A single input partition is already sorted: apply the limit directly
                Some(fetch) => {
                    let stream = self.input.execute(0, context)?;
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}");
                    Ok(Box::pin(LimitStream::new(
                        stream,
                        0,
                        Some(fetch),
                        BaselineMetrics::new(&self.metrics, partition),
                    )))
                }
                // ... or pass the input through unchanged
                None => {
                    let stream = self.input.execute(0, context);
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch");
                    stream
                }
            },
            _ => {
                // Buffer each input partition in a dedicated task so that a
                // slow partition does not stall the others
                let receivers = (0..input_partitions)
                    .map(|partition| {
                        let stream =
                            self.input.execute(partition, Arc::clone(&context))?;
                        Ok(spawn_buffered(stream, 1))
                    })
                    .collect::<Result<_>>()?;

                debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

                let result = StreamingMergeBuilder::new()
                    .with_streams(receivers)
                    .with_schema(schema)
                    .with_expressions(&self.expr)
                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(self.fetch)
                    .with_reservation(reservation)
                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
                    .build()?;

                debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

                Ok(result)
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
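
    /// Tries to rewrite `Projection(SortPreservingMerge(input))` as
    /// `SortPreservingMerge(Projection(input))`. This is only attempted when
    /// the projection narrows the schema, and only succeeds when the sort
    /// expressions can be expressed in terms of the projected columns.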
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, there is nothing to gain:
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
        else {
            return Ok(None);
        };

        Ok(Some(Arc::new(
            SortPreservingMergeExec::new(
                updated_exprs,
                make_with_child(projection, self.input())?,
            )
            .with_fetch(self.fetch()),
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::fmt::Formatter;
    use std::pin::Pin;
    use std::sync::Mutex;
    use std::task::{ready, Context, Poll, Waker};
    use std::time::Duration;

    use super::*;
    use crate::coalesce_batches::CoalesceBatchesExec;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::execution_plan::{Boundedness, EmissionType};
    use crate::expressions::col;
    use crate::metrics::{MetricValue, Timestamp};
    use crate::repartition::RepartitionExec;
    use crate::sorts::sort::SortExec;
    use crate::stream::RecordBatchReceiverStream;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::TestMemoryExec;
    use crate::test::{self, assert_is_pending, make_partition};
    use crate::{collect, common};

    use arrow::array::{
        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
        TimestampNanosecondArray,
    };
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::test_util::batches_to_string;
    use datafusion_common::{assert_batches_eq, exec_err};
    use datafusion_common_runtime::SpawnedTask;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_execution::RecordBatchStream;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::EquivalenceProperties;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

    use futures::{FutureExt, Stream, StreamExt};
    use insta::assert_snapshot;
    use tokio::time::timeout;
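
    /// Creates a task context with a 20 MB memory limit for the round-robin
    /// tie-breaker tests below.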
    fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(20_000_000, 1.0)
            .build_arc()?;
        let config = SessionConfig::new();
        let task_ctx = TaskContext::default()
            .with_runtime(runtime)
            .with_session_config(config);
        Ok(Arc::new(task_ctx))
    }
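
    /// Builds `SortPreservingMergeExec <- CoalesceBatchesExec <-
    /// RepartitionExec <- TestMemoryExec` over 1024 copies of a 12500-row
    /// batch whose sort-key columns are constant, so every comparison
    /// between the two round-robin partitions is a tie.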
    fn generate_spm_for_round_robin_tie_breaker(
        enable_round_robin_repartition: bool,
    ) -> Result<Arc<SortPreservingMergeExec>> {
        let target_batch_size = 12500;
        let row_size = 12500;
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;

        let rbs = (0..1024).map(|_| rb.clone()).collect::<Vec<_>>();

        let schema = rb.schema();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema)?,
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema)?,
                options: Default::default(),
            },
        ]
        .into();

        let repartition_exec = RepartitionExec::try_new(
            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
            Partitioning::RoundRobinBatch(2),
        )?;
        let coalesce_batches_exec =
            CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
        let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
            .with_round_robin_repartition(enable_round_robin_repartition);
        Ok(Arc::new(spm))
    }
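
    /// With the round-robin tie breaker enabled, merging the all-equal
    /// streams is expected to complete within the 20 MB memory limit.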
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_success() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
        let _collected = collect(spm, task_ctx).await?;
        Ok(())
    }
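
    /// With the tie breaker disabled, the same plan is expected to exceed
    /// the 20 MB memory limit and fail.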
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
        let _err = collect(spm, task_ctx).await.unwrap_err();
        Ok(())
    }

    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
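
    /// Merges `partitions` on the lexicographic ordering `(b, c)` and
    /// asserts that the output matches `exp`.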
    async fn _test_merge(
        partitions: &[Vec<RecordBatch>],
        exp: &[&str],
        context: Arc<TaskContext>,
    ) {
        let schema = partitions[0][0].schema();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]
        .into();
        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, context).await.unwrap();
        assert_batches_eq!(exp, collected.as_slice());
    }

    /// Merges `input` with a `SortPreservingMergeExec` and returns the single
    /// resulting batch.
    async fn sorted_merge(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let mut result = collect(merge, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    /// Sorts each partition of `input` separately, then merges them.
    async fn partition_sort(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let sort_exec =
            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
        sorted_merge(sort_exec, sort, context).await
    }

    /// Coalesces all partitions of `src` into one, then sorts the result.
    async fn basic_sort(
        src: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(CoalescePartitionsExec::new(src));
        let sort_exec = Arc::new(SortExec::new(sort, merge));
        let mut result = collect(sort_exec, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    #[tokio::test]
    async fn test_partition_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);
        let schema = csv.schema();

        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]
        .into();

        let basic =
            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    /// Splits `sorted` into batches of at most `batch_size` rows.
    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
        let batches = sorted.num_rows().div_ceil(batch_size);

        // Slice every column of the sorted batch into `batches` chunks
        (0..batches)
            .map(|batch_idx| {
                let columns = (0..sorted.num_columns())
                    .map(|column_idx| {
                        let length =
                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);

                        sorted
                            .column(column_idx)
                            .slice(batch_idx * batch_size, length)
                    })
                    .collect();

                RecordBatch::try_new(sorted.schema(), columns).unwrap()
            })
            .collect()
    }

    /// Sorts the scanned input, then redistributes it into one partition per
    /// entry of `sizes`, each split into batches of the given size.
    async fn sorted_partitioned_input(
        sort: LexOrdering,
        sizes: &[usize],
        context: Arc<TaskContext>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);

        let sorted = basic_sort(csv, sort, context).await;
        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();

        TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition =
            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Compute the reference output with a plain coalesce + sort
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // Then merge with a batch size of 23: the 1200 rows should be
        // re-chunked into ceil(1200 / 23) = 53 output batches
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await?;

        assert_eq!(merged.len(), 53);
        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
        +---+---+-------------------------------+
        | a | b | c                             |
        +---+---+-------------------------------+
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 1 |   | 1970-01-01T00:00:00.000000008 |
        | 2 | a |                               |
        | 7 | b | 1970-01-01T00:00:00.000000006 |
        | 2 | b |                               |
        | 9 | d |                               |
        | 3 | e | 1970-01-01T00:00:00.000000004 |
        | 3 | g | 1970-01-01T00:00:00.000000005 |
        | 4 | h |                               |
        | 5 | i | 1970-01-01T00:00:00.000000004 |
        +---+---+-------------------------------+
        "#);
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge =
            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        +---+---+
        "#);
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        | 7 | c |
        | 9 | d |
        | 3 | e |
        +---+---+
        "#);
    }

    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into();

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    // pause after each batch so the merge has to wait on its inputs
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(&sort)
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_merge_metrics() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
        +----+---+
        | a  | b |
        +----+---+
        | 1  | a |
        | 10 | b |
        | 2  | c |
        | 20 | d |
        +----+---+
        "#);

        // Verify that the merge recorded row counts, compute time and timestamps
        let metrics = merge.metrics().unwrap();

        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }

    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
        ts.value().unwrap().timestamp_nanos_opt().unwrap()
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
            [PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]
            .into(),
            blocking_exec,
        ));

        let fut = collect(sort_preserving_merge_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_stable_sort() {
        let task_ctx = Arc::new(TaskContext::default());

        // Create 10 partitions, each containing one batch with two rows
        // ("A" and "B") tagged with the partition number. All partitions
        // share the same `value` keys, so a stable merge must emit equal
        // keys in partition order.
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        let sort = [PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();

        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r#"
        +--------------+-------+
        | batch_number | value |
        +--------------+-------+
        | 0            | A     |
        | 1            | A     |
        | 2            | A     |
        | 3            | A     |
        | 4            | A     |
        | 5            | A     |
        | 6            | A     |
        | 7            | A     |
        | 8            | A     |
        | 9            | A     |
        | 0            | B     |
        | 1            | B     |
        | 2            | B     |
        | 3            | B     |
        | 4            | B     |
        | 5            | B     |
        | 6            | B     |
        | 7            | B     |
        | 8            | B     |
        | 9            | B     |
        +--------------+-------+
        "#);
    }

    /// Tracks which partitions have been polled, plus the wakers of the
    /// streams that are still waiting for the others.
    #[derive(Debug)]
    struct CongestionState {
        wakers: Vec<Waker>,
        unpolled_partitions: HashSet<usize>,
    }

    #[derive(Debug)]
    struct Congestion {
        congestion_state: Mutex<CongestionState>,
    }

    impl Congestion {
        fn new(partition_count: usize) -> Self {
            Congestion {
                congestion_state: Mutex::new(CongestionState {
                    wakers: vec![],
                    unpolled_partitions: (0usize..partition_count).collect(),
                }),
            }
        }

        /// Marks `partition` as polled; returns `Ready` (and wakes all
        /// waiters) once every partition has been polled at least once.
        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
            let mut state = self.congestion_state.lock().unwrap();

            state.unpolled_partitions.remove(&partition);

            if state.unpolled_partitions.is_empty() {
                state.wakers.iter().for_each(|w| w.wake_by_ref());
                state.wakers.clear();
                Poll::Ready(())
            } else {
                state.wakers.push(cx.waker().clone());
                Poll::Pending
            }
        }
    }
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        congestion: Arc<Congestion>,
    }

    impl CongestedExec {
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_ordering(
                columns
                    .iter()
                    .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
            );
            PlanProperties::new(
                eq_properties,
                Partitioning::Hash(columns, 3),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            )
        }
    }

    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &PlanProperties {
            &self.cache
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion: Arc::clone(&self.congestion),
                partition,
            }))
        }
    }

    impl DisplayAs for CongestedExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CongestedExec",).unwrap()
                }
                DisplayFormatType::TreeRender => {
                    write!(f, "").unwrap()
                }
            }
            Ok(())
        }
    }
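
    /// The stream returned by `CongestedExec`. Partition 0 finishes
    /// immediately (and panics if polled again after returning `None`);
    /// every other partition stays pending until all partitions have been
    /// polled at least once.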
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        none_polled_once: bool,
        congestion: Arc<Congestion>,
        partition: usize,
    }

    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    let _ = self.congestion.check_congested(self.partition, cx);
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                _ => {
                    ready!(self.congestion.check_congested(self.partition, cx));
                    Poll::Ready(None)
                }
            }
        }
    }

    impl RecordBatchStream for CongestedStream {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }
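
    /// Regression test: the merge must keep polling every input partition
    /// even while some of them are pending; otherwise this plan deadlocks
    /// and the 3-second timeout fires.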
    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
        let &partition_count = match properties.output_partitioning() {
            Partitioning::RoundRobinBatch(partitions) => partitions,
            Partitioning::Hash(_, partitions) => partitions,
            Partitioning::UnknownPartitioning(partitions) => partitions,
        };
        let source = CongestedExec {
            schema: schema.clone(),
            cache: properties,
            congestion: Arc::new(Congestion::new(partition_count)),
        };
        let spm = SortPreservingMergeExec::new(
            [PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))]
            .into(),
            Arc::new(source),
        );
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
            Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
        }
    }
}